diff --git a/app/api/tests/upload/test_conll.py b/app/api/tests/upload/test_conll.py new file mode 100644 index 00000000..964ee571 --- /dev/null +++ b/app/api/tests/upload/test_conll.py @@ -0,0 +1,48 @@ +import os +import shutil +import tempfile +import unittest + +from ...views.upload.data import TextData +from ...views.upload.dataset import CoNLLDataset +from ...views.upload.label import OffsetLabel + + +class TestCoNLLDataset(unittest.TestCase): + + def setUp(self): + self.test_dir = tempfile.mkdtemp() + self.test_file = os.path.join(self.test_dir, 'test_file.txt') + self.content = """EU\tB-ORG +rejects\tO +German\tB-MISC +call\tO +to\tO +boycott\tO +British\tB-MISC +lamb\tO +.\tO + +Peter\tB-PER +Blackburn\tI-PER + +""" + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def create_file(self, encoding=None): + with open(self.test_file, 'w', encoding=encoding) as f: + f.write(self.content) + + def test_can_load(self): + self.create_file() + dataset = CoNLLDataset( + filenames=[self.test_file], + label_class=OffsetLabel, + data_class=TextData + ) + it = dataset.load(self.test_file) + record = next(it) + expected = 'EU rejects German call to boycott British lamb .' + self.assertEqual(record.data['text'], expected) diff --git a/app/api/views/upload/catalog.py b/app/api/views/upload/catalog.py index 9ebe2f4a..dbff9e40 100644 --- a/app/api/views/upload/catalog.py +++ b/app/api/views/upload/catalog.py @@ -75,6 +75,7 @@ class OptionNone(BaseModel): class OptionCoNLL(BaseModel): scheme: Literal['IOB2', 'IOE2', 'IOBES', 'BILOU'] = 'IOB2' + delimiter: Literal[' ', ''] = ' ' class Options: @@ -113,7 +114,7 @@ Options.register(DOCUMENT_CLASSIFICATION, Excel, OptionColumn, examples.Category Options.register(SEQUENCE_LABELING, TextFile, OptionNone, examples.Generic_TextFile) Options.register(SEQUENCE_LABELING, TextLine, OptionNone, examples.Generic_TextLine) Options.register(SEQUENCE_LABELING, JSONL, OptionColumn, examples.Offset_JSONL) -Options.register(SEQUENCE_LABELING, CoNLL, OptionNone, examples.Offset_CoNLL) +Options.register(SEQUENCE_LABELING, CoNLL, OptionCoNLL, examples.Offset_CoNLL) # Sequence to sequence Options.register(SEQ2SEQ, TextFile, OptionNone, examples.Generic_TextFile) diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py index 8a813094..4cf75207 100644 --- a/app/api/views/upload/dataset.py +++ b/app/api/views/upload/dataset.py @@ -3,6 +3,7 @@ import json from typing import Dict, Iterator, List, Optional, Type import pyexcel +from seqeval.scheme import IOB2, IOE2, IOBES, BILOU, Tokens from .data import BaseData from .exception import FileParseException @@ -72,7 +73,7 @@ class Dataset: if column_data not in row: message = f'{column_data} does not exist.' raise FileParseException(filename, line_num, message) - text = row.pop(self.kwargs.get('column_data', 'text')) + text = row.pop(column_data) label = row.pop(self.kwargs.get('column_label', 'label'), []) label = [label] if isinstance(label, str) else label label = [self.label_class.parse(o) for o in label] @@ -174,8 +175,48 @@ class FastTextDataset(Dataset): yield record -class ConllDataset(Dataset): +class CoNLLDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: - pass + words, tags = [], [] + delimiter = self.kwargs.get('delimiter', ' ') + for line_num, line in enumerate(f, start=1): + line = line.rstrip() + if line: + tokens = line.split('\t') + if len(tokens) != 2: + message = 'A line must be separated by tab and has two columns.' + raise FileParseException(filename, line_num, message) + word, tag = tokens + words.append(word) + tags.append(tag) + else: + text = delimiter.join(words) + data = self.data_class.parse(filename=filename, text=text) + labels = self.get_label(words, tags, delimiter) + record = Record(data=data, label=labels) + yield record + words, tags = [], [] + + def get_scheme(self, scheme: str): + mapping = { + 'IOB2': IOB2, + 'IOE2': IOE2, + 'IOBES': IOBES, + 'BILOU': BILOU + } + return mapping[scheme] + + def get_label(self, words: List[str], tags: List[str], delimiter: str) -> List[Label]: + scheme = self.get_scheme(self.kwargs.get('scheme', 'IOB2')) + tokens = Tokens(tags, scheme) + labels = [] + for entity in tokens.entities: + text = delimiter.join(words[:entity.start]) + start = len(text) + len(delimiter) if text else len(text) + chunk = words[entity.start: entity.end] + text = delimiter.join(chunk) + end = start + len(text) + labels.append(self.label_class.parse((start, end, entity.tag))) + return labels diff --git a/app/api/views/upload/factory.py b/app/api/views/upload/factory.py index 2e959f24..22d1ec02 100644 --- a/app/api/views/upload/factory.py +++ b/app/api/views/upload/factory.py @@ -18,7 +18,8 @@ def get_dataset_class(format: str): catalog.JSONL.name: dataset.JSONLDataset, catalog.JSON.name: dataset.JSONDataset, catalog.FastText.name: dataset.FastTextDataset, - catalog.Excel.name: dataset.ExcelDataset + catalog.Excel.name: dataset.ExcelDataset, + catalog.CoNLL.name: dataset.CoNLLDataset } if format not in mapping: ValueError(f'Invalid format: {format}') diff --git a/app/api/views/upload/label.py b/app/api/views/upload/label.py index c3473a8b..8d28f420 100644 --- a/app/api/views/upload/label.py +++ b/app/api/views/upload/label.py @@ -62,8 +62,8 @@ class OffsetLabel(Label): @classmethod def parse(cls, obj: Any): - if isinstance(obj, list): - columns = ['label', 'start_offset', 'end_offset'] + if isinstance(obj, list) or isinstance(obj, tuple): + columns = ['start_offset', 'end_offset', 'label'] obj = zip(columns, obj) return cls.parse_obj(obj) elif isinstance(obj, dict): diff --git a/frontend/pages/projects/_id/upload/index.vue b/frontend/pages/projects/_id/upload/index.vue index 62a30197..4a05104e 100644 --- a/frontend/pages/projects/_id/upload/index.vue +++ b/frontend/pages/projects/_id/upload/index.vue @@ -234,6 +234,8 @@ export default { return 'Tab' } else if (text === ' ') { return 'Space' + } else if (text === '') { + return 'None' } else { return text }