mirror of https://github.com/doccano/doccano.git
python annotation-tool datasets active-learning text-annotation dataset natural-language-processing data-labeling machine-learning
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
13 KiB
331 lines
13 KiB
import os
|
|
import pathlib
|
|
import shutil
|
|
|
|
from django.core.files import File
|
|
from django.test import TestCase, override_settings
|
|
from django_drf_filepond.models import StoredUpload, TemporaryUpload
|
|
from django_drf_filepond.utils import _get_file_id
|
|
|
|
from data_import.celery_tasks import import_dataset
|
|
from data_import.pipeline.catalog import RELATION_EXTRACTION
|
|
from examples.models import Example
|
|
from label_types.models import SpanType
|
|
from labels.models import Category, Span
|
|
from projects.models import ProjectType
|
|
from projects.tests.utils import prepare_project
|
|
|
|
|
|
@override_settings(MEDIA_ROOT=os.path.join(os.path.dirname(__file__), "data"))
class TestImportData(TestCase):
    """Base class for dataset-import tests.

    MEDIA_ROOT is pointed at this test package's ``data`` directory. Subclasses set
    ``task`` to the project type under test. Each test gets a fresh project (via
    ``prepare_project``) and a fresh upload id. Any stored upload created for that id
    has its directory removed in ``tearDown``.
    """

    task = "Any"
    annotation_class = Category

    def setUp(self):
        self.project = prepare_project(self.task)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def tearDown(self):
        # Clean up the directory holding the StoredUpload created by this test, if any.
        # Only the lookup is guarded: a missing upload is expected for tests that fail
        # before storing anything, but errors while deleting should still surface.
        try:
            su = StoredUpload.objects.get(upload_id=self.upload_id)
        except StoredUpload.DoesNotExist:
            return
        directory = pathlib.Path(su.get_absolute_file_path()).parent
        shutil.rmtree(directory)

    def import_dataset(self, filename, file_format, task, kwargs=None):
        """Register a fixture file as a temporary upload and run the import task on it.

        Args:
            filename: fixture path relative to ``self.data_path``, using forward
                slashes (e.g. ``"text_classification/example.jsonl"``).
            file_format: format name passed through to the import task.
            task: project/task type passed through to the import task.
            kwargs: optional extra keyword arguments for the import task.

        Returns:
            Whatever the module-level ``import_dataset`` task returns. The tests in this
            module index it as a mapping with an ``"error"`` list.
        """
        file_path = self.data_path / filename
        # The model's FileField copies the content into storage when the row is saved,
        # so the fixture handle can (and should) be closed right after ``create``.
        # Previously it was left open for the rest of the test run.
        with open(file_path, mode="rb") as fixture:
            TemporaryUpload.objects.create(
                upload_id=self.upload_id,
                file_id="1",
                file=File(fixture, filename.split("/")[-1]),
                upload_name=filename,
                upload_type="F",
            )
        upload_ids = [self.upload_id]
        kwargs = kwargs or {}
        # This calls the module-level ``import_dataset`` task, not this method:
        # method names are not in scope inside a method body.
        return import_dataset(self.user.id, self.project.item.id, file_format, upload_ids, task, **kwargs)
|
|
|
|
|
|
@override_settings(MAX_UPLOAD_SIZE=0)
class TestMaxFileSize(TestImportData):
    """With MAX_UPLOAD_SIZE set to 0, an upload must be rejected with a file-size error."""

    task = ProjectType.DOCUMENT_CLASSIFICATION

    def test_jsonl(self):
        response = self.import_dataset(
            "text_classification/example.jsonl",
            "JSONL",
            self.task,
            {"column_label": "labels"},
        )
        errors = response["error"]
        self.assertEqual(len(errors), 1)
        self.assertIn("maximum file size", errors[0]["message"])
|
|
|
|
|
|
class TestInvalidFileFormat(TestImportData):
    """An unknown format name must be reported as exactly one error."""

    task = ProjectType.DOCUMENT_CLASSIFICATION

    def test_invalid_file_format(self):
        response = self.import_dataset("text_classification/example.csv", "INVALID_FORMAT", self.task)
        self.assertEqual(len(response["error"]), 1)
|
|
|
|
|
|
class TestImportClassificationData(TestImportData):
    """Import document-classification fixtures in each supported format.

    Also checks that feeding a file to the wrong parser stores nothing.
    """

    task = ProjectType.DOCUMENT_CLASSIFICATION

    def _import(self, filename, file_format, kwargs=None):
        """Run ``import_dataset`` for this class's task and return its response."""
        return self.import_dataset(filename, file_format, self.task, kwargs)

    def assert_examples(self, dataset):
        """Assert that the stored examples are exactly those in ``dataset``.

        ``dataset`` is a list of ``(text, expected_labels)`` pairs. Label order is
        ignored. The checks run inside ``subTest``, so a failure is recorded and the
        calling test still runs its remaining assertions.
        """
        with self.subTest():
            self.assertEqual(Example.objects.count(), len(dataset))
            for text, expected_labels in dataset:
                categories = Example.objects.get(text=text).categories.all()
                self.assertEqual({category.label.text for category in categories}, set(expected_labels))

    def assert_parse_error(self, response):
        """Assert that the import reported an error and stored no examples or categories."""
        with self.subTest():
            self.assertGreaterEqual(len(response["error"]), 1)
            self.assertEqual(Example.objects.count(), 0)
            self.assertEqual(Category.objects.count(), 0)

    def test_jsonl(self):
        self._import("text_classification/example.jsonl", "JSONL", {"column_label": "labels"})
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])])

    def test_csv(self):
        self._import("text_classification/example.csv", "CSV")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_csv_out_of_order_columns(self):
        # Judging by the fixture name, its columns are ordered differently from
        # example.csv; the parsed result must be the same as in test_csv.
        self._import("text_classification/example_out_of_order_columns.csv", "CSV")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_fasttext(self):
        self._import("text_classification/example_fasttext.txt", "fastText")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])])

    def test_excel(self):
        self._import("text_classification/example.xlsx", "Excel")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", [])])

    def test_json(self):
        self._import("text_classification/example.json", "JSON")
        self.assert_examples([("exampleA", ["positive"]), ("exampleB", ["positive", "negative"]), ("exampleC", [])])

    def test_textfile(self):
        # TextFile keeps the whole file as one example, including the blank line.
        response = self._import("example.txt", "TextFile")
        self.assert_examples([("exampleA\nexampleB\n\nexampleC\n", [])])
        self.assertEqual(len(response["error"]), 0)

    def test_textline(self):
        # TextLine makes one example per line; presumably the blank line in the
        # fixture is what produces the single reported error.
        response = self._import("example.txt", "TextLine")
        self.assert_examples([("exampleA", []), ("exampleB", []), ("exampleC", [])])
        self.assertEqual(len(response["error"]), 1)

    def test_wrong_jsonl(self):
        self.assert_parse_error(self._import("text_classification/example.json", "JSONL"))

    def test_wrong_json(self):
        self.assert_parse_error(self._import("text_classification/example.jsonl", "JSON"))

    def test_wrong_excel(self):
        self.assert_parse_error(self._import("text_classification/example.jsonl", "Excel"))

    def test_wrong_csv(self):
        self.assert_parse_error(self._import("text_classification/example.jsonl", "CSV"))
|
|
|
|
|
|
class TestImportSequenceLabelingData(TestImportData):
    """Import sequence-labeling (span) fixtures."""

    task = ProjectType.SEQUENCE_LABELING

    def assert_examples(self, dataset):
        """Assert that the stored examples and their spans match ``dataset``.

        ``dataset`` is a list of ``(text, spans)`` pairs, where each span is
        ``[start_offset, end_offset, label_text]``. Spans are compared ignoring order.
        ``example.spans.all()`` has no ``order_by``, so unless the Span model declares a
        default ordering (not visible here), the database may return rows in any order.
        A positional comparison would then be flaky.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            example = Example.objects.get(text=text)
            labels = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
            self.assertCountEqual(labels, expected_labels)

    def assert_parse_error(self, response):
        """Assert that the import reported an error and stored no examples, span types or spans."""
        self.assertGreaterEqual(len(response["error"]), 1)
        self.assertEqual(Example.objects.count(), 0)
        self.assertEqual(SpanType.objects.count(), 0)
        self.assertEqual(Span.objects.count(), 0)

    def test_jsonl(self):
        filename = "sequence_labeling/example.jsonl"
        file_format = "JSONL"
        dataset = [("exampleA", [[0, 1, "LOC"]]), ("exampleB", [])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_conll(self):
        filename = "sequence_labeling/example.conll"
        file_format = "CoNLL"
        dataset = [("JAPAN GET", [[0, 5, "LOC"]]), ("Nadim Ladki", [[0, 11, "PER"]])]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)

    def test_wrong_conll(self):
        filename = "sequence_labeling/example.jsonl"
        file_format = "CoNLL"
        response = self.import_dataset(filename, file_format, self.task)
        self.assert_parse_error(response)

    def test_jsonl_with_overlapping(self):
        # Overlapping spans must import without reporting errors.
        filename = "sequence_labeling/example_overlapping.jsonl"
        file_format = "JSONL"
        response = self.import_dataset(filename, file_format, self.task)
        self.assertEqual(len(response["error"]), 0)
|
|
|
|
|
|
class TestImportRelationExtractionData(TestImportData):
    """Import relation-extraction fixtures into a sequence-labeling project with relations enabled."""

    task = ProjectType.SEQUENCE_LABELING

    def setUp(self):
        # Same as the base setUp, except the project is created with use_relation=True.
        self.project = prepare_project(self.task, use_relation=True)
        self.user = self.project.admin
        self.data_path = pathlib.Path(__file__).parent / "data"
        self.upload_id = _get_file_id()

    def assert_examples(self, dataset):
        """Assert that the stored examples, their spans and their relation counts match.

        ``dataset`` is a list of ``(text, spans)`` pairs, where each span is
        ``[start_offset, end_offset, label_text]``. Spans are compared ignoring order.
        ``example.spans.all()`` has no ``order_by``, so unless the Span model declares a
        default ordering (not visible here), row order is unspecified.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_spans in dataset:
            example = Example.objects.get(text=text)
            spans = [[span.start_offset, span.end_offset, span.label.text] for span in example.spans.all()]
            self.assertCountEqual(spans, expected_spans)
            # NOTE: fixture-specific. This helper expects every example to carry exactly
            # three relations, which holds for the single example in
            # relation_extraction/example.jsonl.
            self.assertEqual(example.relations.count(), 3)

    def assert_parse_error(self, response):
        """Assert that the import reported an error and stored no examples, span types or spans."""
        self.assertGreaterEqual(len(response["error"]), 1)
        self.assertEqual(Example.objects.count(), 0)
        self.assertEqual(SpanType.objects.count(), 0)
        self.assertEqual(Span.objects.count(), 0)

    def test_jsonl(self):
        filename = "relation_extraction/example.jsonl"
        file_format = "JSONL"
        dataset = [
            (
                "Google was founded on September 4, 1998, by Larry Page and Sergey Brin.",
                [[0, 6, "ORG"], [22, 39, "DATE"], [44, 54, "PERSON"], [59, 70, "PERSON"]],
            ),
        ]
        # The project type is SEQUENCE_LABELING, but the import runs as RELATION_EXTRACTION.
        self.import_dataset(filename, file_format, RELATION_EXTRACTION)
        self.assert_examples(dataset)
|
|
|
|
|
|
class TestImportSeq2seqData(TestImportData):
    """Import sequence-to-sequence fixtures (examples paired with free-text labels)."""

    task = ProjectType.SEQ2SEQ

    def assert_examples(self, dataset):
        """Assert that exactly the examples in ``dataset`` exist, each with the expected text labels.

        ``dataset`` is a list of ``(text, expected_labels)`` pairs. Label order is ignored.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            stored = Example.objects.get(text=text).texts.all()
            self.assertEqual({label.text for label in stored}, set(expected_labels))

    def _import_and_check(self, filename, file_format):
        """Import ``filename`` and assert it yields the content shared by all seq2seq fixtures."""
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples([("exampleA", ["label1"]), ("exampleB", [])])

    def test_jsonl(self):
        self._import_and_check("seq2seq/example.jsonl", "JSONL")

    def test_json(self):
        self._import_and_check("seq2seq/example.json", "JSON")

    def test_csv(self):
        self._import_and_check("seq2seq/example.csv", "CSV")
|
|
|
|
|
|
class TestImportIntentDetectionAndSlotFillingData(TestImportData):
    """Import fixtures that carry both document categories and entity spans."""

    task = ProjectType.INTENT_DETECTION_AND_SLOT_FILLING

    def assert_examples(self, dataset):
        """Assert that the stored examples match ``dataset`` in both categories and entities.

        ``dataset`` is a list of ``(text, {"cats": [...], "entities": [(start, end, label), ...]})``.
        Categories are compared as sets. Entities are compared as an unordered multiset,
        because ``example.spans.all()`` has no ``order_by``. Unless the Span model declares
        a default ordering (not visible here), row order is unspecified.
        """
        self.assertEqual(Example.objects.count(), len(dataset))
        for text, expected_labels in dataset:
            example = Example.objects.get(text=text)
            cats = {cat.label.text for cat in example.categories.all()}
            entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
            self.assertEqual(cats, set(expected_labels["cats"]))
            self.assertCountEqual(entities, expected_labels["entities"])

    def test_entities_and_cats(self):
        # Covers all four combinations: both kinds, categories only, entities only, neither.
        filename = "intent/example.jsonl"
        file_format = "JSONL"
        dataset = [
            ("exampleA", {"cats": ["positive"], "entities": [(0, 1, "LOC")]}),
            ("exampleB", {"cats": ["positive"], "entities": []}),
            ("exampleC", {"cats": [], "entities": [(0, 1, "LOC")]}),
            ("exampleD", {"cats": [], "entities": []}),
        ]
        self.import_dataset(filename, file_format, self.task)
        self.assert_examples(dataset)
|
|
|
|
|
|
class TestImportImageClassificationData(TestImportData):
    """Importing a single image file must create exactly one example."""

    task = ProjectType.IMAGE_CLASSIFICATION

    def test_example(self):
        self.import_dataset("images/1500x500.jpeg", "ImageFile", self.task)
        self.assertEqual(Example.objects.count(), 1)
|
|
|
|
|
|
@override_settings(ENABLE_FILE_TYPE_CHECK=True)
class TestFileTypeChecking(TestImportData):
    """With file-type checking enabled, an ``.ico`` upload for the ImageFile format must be rejected."""

    task = ProjectType.IMAGE_CLASSIFICATION

    def test_example(self):
        response = self.import_dataset("images/example.ico", "ImageFile", self.task)
        errors = response["error"]
        self.assertEqual(len(errors), 1)
        self.assertIn("unexpected", errors[0]["message"])
|