diff --git a/app/api/tests/data/labeling.invalid.conll b/app/api/tests/data/labeling.invalid.conll index 57c390cd..3e153a0f 100644 --- a/app/api/tests/data/labeling.invalid.conll +++ b/app/api/tests/data/labeling.invalid.conll @@ -1,4 +1,4 @@ -SOCCERO +SOCCERO SOCCERO SOCCERO - O JAPAN B-LOC GET O diff --git a/app/api/tests/test_utils.py b/app/api/tests/test_utils.py index 5d4874b8..30e00063 100644 --- a/app/api/tests/test_utils.py +++ b/app/api/tests/test_utils.py @@ -143,13 +143,16 @@ class TestSeq2seqStorage(TestCase): class TestCoNLLParser(TestCase): def test_calc_char_offset(self): - words = ['EU', 'rejects', 'German', 'call'] - tags = ['B-ORG', 'O', 'B-MISC', 'O'] + f = io.BytesIO() - entities = get_entities(tags) - actual = CoNLLParser.calc_char_offset(words, tags) + s = [ + ("EU", "ORG"), ("rejects", "_"), ("German", "MISC"), ("call", "_") + ] + for w, t in s: + f.write("{}\t{}\n".format(w, t).encode()) + f.seek(0) - self.assertEqual(entities, [('ORG', 0, 0), ('MISC', 2, 2)]) + actual = next(CoNLLParser().parse(f))[0] self.assertEqual(actual, { 'text': 'EU rejects German call', diff --git a/app/api/utils.py b/app/api/utils.py index 2f9fca92..04a50b38 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -6,6 +6,7 @@ import re from collections import defaultdict from random import Random +import conllu from django.db import transaction from django.conf import settings from rest_framework.renderers import JSONRenderer @@ -242,45 +243,51 @@ class CoNLLParser(FileParser): ``` """ def parse(self, file): - words, tags = [], [] data = [] file = io.TextIOWrapper(file, encoding='utf-8') - for i, line in enumerate(file, start=1): - if len(data) >= settings.IMPORT_BATCH_SIZE: - yield data - data = [] - line = line.strip() - if line: - try: - word, tag = line.split('\t') - except ValueError: - raise FileParseException(line_num=i, line=line) + + # Add check exception + + field_parsers = { + "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]), + } + + try: + sentences = conllu.parse( + file.read(), + fields=("form", "ne"), + field_parsers=field_parsers + ) + except conllu.parser.ParseException as e: + raise FileParseException(line_num=-1, line=str(e)) + + for sentence in sentences: + if not sentence: + continue + # if len(data) >= settings.IMPORT_BATCH_SIZE: + # yield data + # data = [] + words, labels = [], [] + for item in sentence: + word = item.get("form") + tag = item.get("ne", None) + + if tag is not None: + char_left = sum(map(lambda x: len(x), words)) + len(words) + char_right = char_left + len(word) + span = [char_left, char_right, tag] + labels.append(span) + words.append(word) - tags.append(tag) - elif words and tags: - j = self.calc_char_offset(words, tags) - data.append(j) - words, tags = [], [] - if len(words) > 0: - j = self.calc_char_offset(words, tags) + + # Create JSONL + j = {'text': ' '.join(words), 'labels': labels} + data.append(j) + if data: yield data - @classmethod - def calc_char_offset(cls, words, tags): - doc = ' '.join(words) - j = {'text': ' '.join(words), 'labels': []} - pos = defaultdict(int) - for label, start_offset, end_offset in get_entities(tags): - entity = ' '.join(words[start_offset: end_offset + 1]) - char_left = doc.index(entity, pos[entity]) - char_right = char_left + len(entity) - span = [char_left, char_right, label] - j['labels'].append(span) - pos[entity] = char_right - return j - class PlainTextParser(FileParser): """Uploads plain text. @@ -373,6 +380,7 @@ class JSONLRenderer(JSONRenderer): ensure_ascii=self.ensure_ascii, allow_nan=not self.strict) + '\n' + class JSONPainter(object): def paint(self, documents): @@ -406,6 +414,7 @@ class JSONPainter(object): data.append(d) return data + class CSVPainter(JSONPainter): def paint(self, documents): diff --git a/requirements.txt b/requirements.txt index af1b6582..841d6b56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ djangorestframework-csv==2.1.0 djangorestframework-filters==0.10.2 environs==4.1.0 djangorestframework-xml==1.4.0 -Faker==0.8.8 +Faker==0.9.1 flake8==3.6.0 furl==2.0.0 gunicorn==19.9.0 @@ -36,3 +36,4 @@ unittest-xml-reporting==2.5.1 vcrpy==2.0.1 vcrpy-unittest==0.1.7 whitenoise[brotli]==4.1.2 +conllu