Browse Source

Add Import-Export-Integrity-Tests

pull/1300/head
Chris 3 years ago
parent
commit
3a342633a3
2 changed files with 99 additions and 2 deletions
  1. 2
      app/api/tests/data/example.csv
  2. 99
      app/api/tests/test_api.py

2
app/api/tests/data/example.csv

@ -1,5 +1,5 @@
text,label,meta text,label,meta
AAA
AAA,,
BBB,Positive,The following is meta data BBB,Positive,The following is meta data
CCC,Negative CCC,Negative
DDD,,This is meta data DDD,,This is meta data

99
app/api/tests/test_api.py

@ -1,4 +1,5 @@
import os import os
from io import BytesIO
from django.conf import settings from django.conf import settings
from django.test import override_settings from django.test import override_settings
@ -9,10 +10,11 @@ from rest_framework.test import APITestCase
from ..exceptions import FileParseException from ..exceptions import FileParseException
from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
SPEECH2TEXT, Comment, Document, Role, RoleMapping,
SPEECH2TEXT, Document, Role, RoleMapping,
SequenceAnnotation, User) SequenceAnnotation, User)
from ..utils import (CoNLLParser, CSVParser, FastTextParser, JSONParser, from ..utils import (CoNLLParser, CSVParser, FastTextParser, JSONParser,
PlainTextParser) PlainTextParser)
from ..views import TextUploadAPI
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
@ -1225,6 +1227,101 @@ class TestFilter(APITestCase):
def doCleanups(cls): def doCleanups(cls):
remove_all_role_mappings() remove_all_role_mappings()
class TestImportExportIntegrity(APITestCase):
"""Tests that check for equality between imported and exported data of a file. """
@classmethod
def setUpTestData(cls):
cls.super_user_name = 'super_user_name'
cls.super_user_pass = 'super_user_pass'
create_default_roles()
super_user = User.objects.create_superuser(username=cls.super_user_name,
password=cls.super_user_pass,
email='fizz@buzz.com')
cls.classification_project = mommy.make('TextClassificationProject',
users=[super_user], project_type=DOCUMENT_CLASSIFICATION)
cls.classification_upload_url = reverse(viewname='doc_uploader', args=[cls.classification_project.id])
cls.classification_download_url = reverse(viewname='doc_downloader', args=[cls.classification_project.id])
assign_user_to_role(project_member=super_user, project=cls.classification_project,
role_name=settings.ROLE_PROJECT_ADMIN)
def setUp(self):
self.client.login(username=self.super_user_name,
password=self.super_user_pass)
def load_test_helper(self, upload_url, download_url, filename, import_format, export_format, response_format="text/csv; charset=utf-8",**kwargs):
parser = TextUploadAPI.select_parser(import_format)
with open(os.path.join(DATA_DIR, filename), 'rb') as f:
self.client.post(upload_url, data={'file': f, 'format': import_format})
f.seek(0)
imported = parser.parse(f)
import_data = [elem for elem in [x for elem in imported for x in elem] if elem.get('labels')]
r = self.client.get(download_url, data={'q': export_format}, HTTP_ACCEPT=response_format)
b = BytesIO(r.content)
if export_format == 'txt':
export_format = import_format
parser = TextUploadAPI.select_parser(export_format)
exported = parser.parse(b)
exported_data = [x for elem in exported for x in elem if x.get('labels') or x.get('annotations')]
self.assertTrue(len(import_data) == len(exported_data), 'Length of imported dataset does not match exported')
return import_data, exported_data
# Classification
def test_jsonl_classification_import_export_integrity(self):
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url,
download_url=self.classification_download_url,
filename='classification.jsonl',
import_format='json',
project=self.classification_project,
export_format='json',
response_format='application/json')
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()}
for im, ex in zip(import_data, export_data):
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.')
ex_labels = set(label_mapping[int(x.get('label'))] for x in ex.get('annotations', []))
self.assertFalse(set(im.get('labels')).symmetric_difference(ex_labels), 'Integritycheck failed. Labels differ.')
def test_csv_classification_import_export_integrity(self):
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url,
download_url=self.classification_download_url,
filename='example.csv',
import_format='csv',
project=self.classification_project,
export_format='csv')
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()}
for im, ex in zip(import_data, export_data):
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.')
ex_labels = [label_mapping[int(elem)] for elem in ex.get('labels', [])]
self.assertTrue(im.get('labels') == ex_labels, 'Integritycheck failed. Labels differ.')
def test_xlsx_classification_import_export_integrity(self):
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url,
download_url=self.classification_download_url,
filename='example.xlsx',
import_format='excel',
project=self.classification_project,
export_format='csv')
label_mapping = {label.id: label.text for label in self.classification_project.labels.all()}
for im, ex in zip(import_data, export_data):
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.')
ex_labels = [label_mapping[int(elem)] for elem in ex.get('labels', [])]
self.assertTrue(im.get('labels') == ex_labels, 'Integritycheck failed. Labels differ.')
def test_fasttext_classification_import_export_integrity(self):
import_data, export_data = self.load_test_helper(upload_url=self.classification_upload_url,
download_url=self.classification_download_url,
filename='example_fasttext.txt',
import_format='fastText',
project=self.classification_project,
export_format='txt')
for im, ex in zip(import_data, export_data):
self.assertTrue(im['text'] == ex['text'], 'Integritycheck failed. Dataset texts do not match.')
self.assertFalse(set(im.get('labels')).symmetric_difference(ex.get('labels')), 'Integritycheck failed. Labels differ.')
@classmethod
def doCleanups(cls):
remove_all_role_mappings()
class TestUploader(APITestCase): class TestUploader(APITestCase):

Loading…
Cancel
Save