diff --git a/app/api/tests/data/labeling.invalid.conll b/app/api/tests/data/labeling.invalid.conll index 57c390cd..3e153a0f 100644 --- a/app/api/tests/data/labeling.invalid.conll +++ b/app/api/tests/data/labeling.invalid.conll @@ -1,4 +1,4 @@ -SOCCERO +SOCCERO SOCCERO SOCCERO - O JAPAN B-LOC GET O diff --git a/app/api/tests/test_utils.py b/app/api/tests/test_utils.py index 5d4874b8..34d0e9ac 100644 --- a/app/api/tests/test_utils.py +++ b/app/api/tests/test_utils.py @@ -143,13 +143,14 @@ class TestSeq2seqStorage(TestCase): class TestCoNLLParser(TestCase): def test_calc_char_offset(self): - words = ['EU', 'rejects', 'German', 'call'] - tags = ['B-ORG', 'O', 'B-MISC', 'O'] - - entities = get_entities(tags) - actual = CoNLLParser.calc_char_offset(words, tags) - - self.assertEqual(entities, [('ORG', 0, 0), ('MISC', 2, 2)]) + f = io.BytesIO( + b"EU\tORG\n" + b"rejects\t_\n" + b"German\tMISC\n" + b"call\t_\n" + ) + + actual = next(CoNLLParser().parse(f))[0] self.assertEqual(actual, { 'text': 'EU rejects German call', diff --git a/app/api/utils.py b/app/api/utils.py index d158eb79..2ecda302 100644 --- a/app/api/utils.py +++ b/app/api/utils.py @@ -6,6 +6,7 @@ import re from collections import defaultdict from random import Random +import conllu from django.db import transaction from django.conf import settings from rest_framework.renderers import JSONRenderer @@ -242,45 +243,50 @@ class CoNLLParser(FileParser): ``` """ def parse(self, file): - words, tags = [], [] data = [] file = io.TextIOWrapper(file, encoding='utf-8') - for i, line in enumerate(file, start=1): - if len(data) >= settings.IMPORT_BATCH_SIZE: - yield data - data = [] - line = line.strip() - if line: - try: - word, tag = line.split('\t') - except ValueError: - raise FileParseException(line_num=i, line=line) - words.append(word) - tags.append(tag) - elif words and tags: - j = self.calc_char_offset(words, tags) - data.append(j) - words, tags = [], [] - if len(words) > 0: - j = self.calc_char_offset(words, tags) - data.append(j) + + # Add check exception + + field_parsers = { + "ne": lambda line, i: conllu.parser.parse_nullable_value(line[i]), + } + + gen_parser = conllu.parse_incr( + file, + fields=("form", "ne"), + field_parsers=field_parsers + ) + + try: + for sentence in gen_parser: + if not sentence: + continue + if len(data) >= settings.IMPORT_BATCH_SIZE: + yield data + data = [] + words, labels = [], [] + for item in sentence: + word = item.get("form") + tag = item.get("ne") + + if tag is not None: + char_left = sum(map(len, words)) + len(words) + char_right = char_left + len(word) + span = [char_left, char_right, tag] + labels.append(span) + + words.append(word) + + # Create and add JSONL + data.append({'text': ' '.join(words), 'labels': labels}) + + except conllu.parser.ParseException as e: + raise FileParseException(line_num=-1, line=str(e)) + if data: yield data - @classmethod - def calc_char_offset(cls, words, tags): - doc = ' '.join(words) - j = {'text': ' '.join(words), 'labels': []} - pos = defaultdict(int) - for label, start_offset, end_offset in get_entities(tags): - entity = ' '.join(words[start_offset: end_offset + 1]) - char_left = doc.index(entity, pos[entity]) - char_right = char_left + len(entity) - span = [char_left, char_right, label] - j['labels'].append(span) - pos[entity] = char_right - return j - class PlainTextParser(FileParser): """Uploads plain text. diff --git a/requirements.txt b/requirements.txt index 170f6c11..bc603cc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ djangorestframework-csv==2.1.0 djangorestframework-filters==0.10.2 environs==4.1.0 djangorestframework-xml==1.4.0 -Faker==0.8.8 +Faker==0.9.1 flake8==3.6.0 furl==2.0.0 gunicorn==19.9.0 @@ -37,3 +37,4 @@ unittest-xml-reporting==2.5.1 vcrpy==2.0.1 vcrpy-unittest==0.1.7 whitenoise[brotli]==4.1.2 +conllu==1.3.2