diff --git a/backend/api/tests/data/seq2seq/example.csv b/backend/api/tests/data/seq2seq/example.csv new file mode 100644 index 00000000..d8ccc403 --- /dev/null +++ b/backend/api/tests/data/seq2seq/example.csv @@ -0,0 +1,5 @@ +text,label,meta +AAA,LabelA, +BBB,LabelB,MetaB +CCC,LabelC +DDD,,MetaD diff --git a/backend/api/tests/data/seq2seq/example.json b/backend/api/tests/data/seq2seq/example.json new file mode 100644 index 00000000..dfac44b9 --- /dev/null +++ b/backend/api/tests/data/seq2seq/example.json @@ -0,0 +1,5 @@ +[ + {"text": "example", "label": ["example1", "example2"]}, + {"text": "example", "label": ["example"]}, + {"text": "example", "label": ["example"]} +] diff --git a/backend/api/tests/data/seq2seq/example.jsonl b/backend/api/tests/data/seq2seq/example.jsonl index a43d34f5..2069686e 100644 --- a/backend/api/tests/data/seq2seq/example.jsonl +++ b/backend/api/tests/data/seq2seq/example.jsonl @@ -1,3 +1,3 @@ -{"text": "example", "labels": ["example1", "example2"], "meta": {"wikiPageID": 1}} -{"text": "example", "labels": ["example"], "meta": {"wikiPageID": 2}} -{"text": "example", "labels": ["example"], "meta": {"wikiPageID": 3}} +{"text": "example", "label": ["example1", "example2"], "meta": {"wikiPageID": 1}} +{"text": "example", "label": ["example"], "meta": {"wikiPageID": 2}} +{"text": "example", "label": ["example"], "meta": {"wikiPageID": 3}} diff --git a/backend/api/tests/test_tasks.py b/backend/api/tests/test_tasks.py index 40011bd3..db7ea29d 100644 --- a/backend/api/tests/test_tasks.py +++ b/backend/api/tests/test_tasks.py @@ -2,7 +2,8 @@ import pathlib from django.test import TestCase -from ..models import Category, Example, Label, Span +from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, + Category, Example, Label, Span, TextLabel) from ..tasks import injest_data from .api.utils import prepare_project @@ -32,7 +33,7 @@ class TestIngestData(TestCase): class TestIngestClassificationData(TestIngestData): - task = 'DocumentClassification' + task = DOCUMENT_CLASSIFICATION annotation_class = Category def test_jsonl(self): @@ -73,7 +74,7 @@ class TestIngestClassificationData(TestIngestData): class TestIngestSequenceLabelingData(TestIngestData): - task = 'SequenceLabeling' + task = SEQUENCE_LABELING annotation_class = Span def test_jsonl(self): @@ -85,3 +86,23 @@ class TestIngestSequenceLabelingData(TestIngestData): filename = 'sequence_labeling/example.conll' file_format = 'CoNLL' self.assert_count(filename, file_format, expected_example=3, expected_label=2, expected_annotation=5) + + +class TestIngestSeq2seqData(TestIngestData): + task = SEQ2SEQ + annotation_class = TextLabel + + def test_jsonl(self): + filename = 'seq2seq/example.jsonl' + file_format = 'JSONL' + self.assert_count(filename, file_format, expected_example=3, expected_label=0, expected_annotation=4) + + def test_json(self): + filename = 'seq2seq/example.json' + file_format = 'JSON' + self.assert_count(filename, file_format, expected_example=3, expected_label=0, expected_annotation=4) + + def test_csv(self): + filename = 'seq2seq/example.csv' + file_format = 'CSV' + self.assert_count(filename, file_format, expected_example=4, expected_label=0, expected_annotation=3) diff --git a/backend/api/views/upload/dataset.py b/backend/api/views/upload/dataset.py index 3e234e12..477eb5a7 100644 --- a/backend/api/views/upload/dataset.py +++ b/backend/api/views/upload/dataset.py @@ -112,7 +112,7 @@ class Dataset: label = [label] if isinstance(label, str) else label try: label = [self.label_class.parse(o) for o in label] - except pydantic.error_wrappers.ValidationError: + except (pydantic.error_wrappers.ValidationError, TypeError): label = [] data = self.data_class.parse(text=text, filename=filename, meta=row) record = Record(data=data, label=label) diff --git a/backend/api/views/upload/label.py b/backend/api/views/upload/label.py index 8d28f420..7fb220ae 100644 --- a/backend/api/views/upload/label.py +++ b/backend/api/views/upload/label.py @@ -92,10 +92,10 @@ class TextLabel(Label): @classmethod def parse(cls, obj: Any): - if isinstance(obj, str): + if isinstance(obj, str) and obj: return cls(text=obj) else: - raise TypeError(f'{obj} is not str.') + raise TypeError(f'{obj} is not str or empty.') def replace(self, mapping: Dict[str, str]) -> 'Label': return self