|
|
@ -1,4 +1,5 @@ |
|
|
|
import os |
|
|
|
from io import BytesIO |
|
|
|
|
|
|
|
from django.conf import settings |
|
|
|
from django.test import override_settings |
|
|
@ -9,10 +10,11 @@ from rest_framework.test import APITestCase |
|
|
|
|
|
|
|
from ..exceptions import FileParseException |
|
|
|
from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, |
|
|
|
SPEECH2TEXT, Comment, Document, Role, RoleMapping, |
|
|
|
SPEECH2TEXT, Document, Role, RoleMapping, |
|
|
|
SequenceAnnotation, User) |
|
|
|
from ..utils import (CoNLLParser, CSVParser, FastTextParser, JSONParser, |
|
|
|
PlainTextParser) |
|
|
|
from ..views import TextUploadAPI |
|
|
|
|
|
|
|
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') |
|
|
|
|
|
|
@ -1225,6 +1227,101 @@ class TestFilter(APITestCase): |
|
|
|
def doCleanups(cls): |
|
|
|
remove_all_role_mappings() |
|
|
|
|
|
|
|
class TestImportExportIntegrity(APITestCase): |
|
|
|
"""Tests that check for equality between imported and exported data of a file. """ |
|
|
|
@classmethod |
|
|
|
def setUpTestData(cls): |
|
|
|
cls.super_user_name = 'super_user_name' |
|
|
|
cls.super_user_pass = 'super_user_pass' |
|
|
|
create_default_roles() |
|
|
|
super_user = User.objects.create_superuser(username=cls.super_user_name, |
|
|
|
password=cls.super_user_pass, |
|
|
|
email='fizz@buzz.com') |
|
|
|
cls.classification_project = mommy.make('TextClassificationProject', |
|
|
|
users=[super_user], project_type=DOCUMENT_CLASSIFICATION) |
|
|
|
|
|
|
|
cls.classification_upload_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id]) |
|
|
|
cls.classification_download_url = reverse(viewname='doc_downloader', args=[cls.classification_project.id]) |
|
|
|
assign_user_to_role(project_member=super_user, project=cls.classification_project, |
|
|
|
role_name=settings.ROLE_PROJECT_ADMIN) |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.client.login(username=self.super_user_name, |
|
|
|
password=self.super_user_pass) |
|
|
|
|
|
|
|
def load_test_helper(self, upload_url, download_url, filename, import_format, export_format, response_format="text/csv; charset=utf-8",**kwargs): |
|
|
|
parser = TextUploadAPI.select_parser(import_format) |
|
|
|
with open(os.path.join(DATA_DIR, filename), 'rb') as f: |
|
|
|
self.client.post(upload_url, data={'file': f, 'format': import_format}) |
|
|
|
f.seek(0) |
|
|
|
imported = parser.parse(f) |
|
|
|
import_data = [elem for elem in [x for elem in imported for x in elem] if elem.get('labels')] |
|
|
|
|
|
|
|
r = self.client.get(download_url, data={'q': export_format}, HTTP_ACCEPT=response_format) |
|
|
|
b = BytesIO(r.content) |
|
|
|
if export_format == 'txt': |
|
|
|
export_format = import_format |
|
|
|
parser = TextUploadAPI.select_parser(export_format) |
|
|
|
exported = parser.parse(b) |
|
|
|
exported_data = [x for elem in exported for x in elem if x.get('labels') or x.get('annotations')] |
|
|
|
self.assertTrue(len(import_data) == len(exported_data), 'Length of imported dataset does not match exported') |
|
|
|
return import_data, exported_data |
|
|
|
|
|
|
|
# Classification |
|
|
|
def test_jsonl_classification_import_export_integrity(self): |
|
|
|
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url, |
|
|
|
download_url=self.classification_download_url, |
|
|
|
filename='classification.jsonl', |
|
|
|
import_format='json', |
|
|
|
project=self.classification_project, |
|
|
|
export_format='json', |
|
|
|
response_format='application/json') |
|
|
|
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()} |
|
|
|
for im, ex in zip(import_data, export_data): |
|
|
|
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.') |
|
|
|
ex_labels = set(label_mapping[int(x.get('label'))] for x in ex.get('annotations', [])) |
|
|
|
self.assertFalse(set(im.get('labels')).symmetric_difference(ex_labels), 'Integritycheck failed. Labels differ.') |
|
|
|
|
|
|
|
def test_csv_classification_import_export_integrity(self): |
|
|
|
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url, |
|
|
|
download_url=self.classification_download_url, |
|
|
|
filename='example.csv', |
|
|
|
import_format='csv', |
|
|
|
project=self.classification_project, |
|
|
|
export_format='csv') |
|
|
|
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()} |
|
|
|
for im, ex in zip(import_data, export_data): |
|
|
|
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.') |
|
|
|
ex_labels = [label_mapping[int(elem)] for elem in ex.get('labels', [])] |
|
|
|
self.assertTrue(im.get('labels') == ex_labels, 'Integritycheck failed. Labels differ.') |
|
|
|
|
|
|
|
def test_xlsx_classification_import_export_integrity(self): |
|
|
|
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url, |
|
|
|
download_url=self.classification_download_url, |
|
|
|
filename='example.xlsx', |
|
|
|
import_format='excel', |
|
|
|
project=self.classification_project, |
|
|
|
export_format='csv') |
|
|
|
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()} |
|
|
|
for im, ex in zip(import_data, export_data): |
|
|
|
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.') |
|
|
|
ex_labels = [label_mapping[int(elem)] for elem in ex.get('labels', [])] |
|
|
|
self.assertTrue(im.get('labels') == ex_labels, 'Integritycheck failed. Labels differ.') |
|
|
|
|
|
|
|
def test_fasttext_classification_import_export_integrity(self): |
|
|
|
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url, |
|
|
|
download_url=self.classification_download_url, |
|
|
|
filename='example_fasttext.txt', |
|
|
|
import_format='fastText', |
|
|
|
project=self.classification_project, |
|
|
|
export_format='txt') |
|
|
|
for im, ex in zip(import_data, export_data): |
|
|
|
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.') |
|
|
|
self.assertFalse(set(im.get('labels')).symmetric_difference(ex.get('labels')), 'Integritycheck failed. Labels differ.') |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def doCleanups(cls): |
|
|
|
remove_all_role_mappings() |
|
|
|
|
|
|
|
class TestUploader(APITestCase): |
|
|
|
|
|
|
|