Upload builder: accept integer cell values and log skipped label columns.

Summary of the changes carried by this patch:
* backend/api/tasks.py: wrap the long `factories` import across two lines.
* backend/api/tests/upload/test_builder.py: factor record construction into a
  `create_record` helper and add tests for custom column names, a missing
  data column, empty text, and integer text.
* backend/api/views/upload/builders.py: treat a bare int label like a bare
  str, give the abstract `Column.__call__` a real error message, and log
  (instead of silently ignoring) label columns that fail to parse. The log
  call passes its values as lazy logging arguments.
* backend/api/views/upload/label.py: `CategoryLabel.parse` converts ints to
  str and its TypeError message names both accepted types;
  `OffsetLabel.label` is typed `Union[str, int]`.

Note: the `index` lines keep the blob ids of the original patch.

diff --git a/backend/api/tasks.py b/backend/api/tasks.py
index b41b7bde..9be04894 100644
--- a/backend/api/tasks.py
+++ b/backend/api/tasks.py
@@ -7,7 +7,8 @@ from django.shortcuts import get_object_or_404
 from .models import Project
 from .views.download.factory import create_repository, create_writer
 from .views.download.service import ExportApplicationService
-from .views.upload.factories import create_bulder, create_cleaner, create_parser
+from .views.upload.factories import (create_bulder, create_cleaner,
+                                     create_parser)
 from .views.upload.readers import Reader
 from .views.upload.writers import BulkWriter
 
diff --git a/backend/api/tests/upload/test_builder.py b/backend/api/tests/upload/test_builder.py
index a7b9a645..98a8fbd4 100644
--- a/backend/api/tests/upload/test_builder.py
+++ b/backend/api/tests/upload/test_builder.py
@@ -1,7 +1,9 @@
 import unittest
+from typing import List
 
 from ...views.upload import builders
 from ...views.upload.data import TextData
+from ...views.upload.exception import FileParseException
 from ...views.upload.label import CategoryLabel
 
 
@@ -11,26 +13,55 @@ class TestColumnBuilder(unittest.TestCase):
         self.assertEqual(actual.data.text, expected['data'])
         self.assertEqual(actual.label, expected['label'])
 
-    def test_can_load_default_column_names(self):
-        row = {'text': 'Text', 'label': 'Label'}
-        data_column = builders.DataColumn('text', TextData)
-        label_columns = [builders.LabelColumn('label', CategoryLabel)]
+    def create_record(self, row, data_column: builders.DataColumn, label_columns: List[builders.LabelColumn]):
         builder = builders.ColumnBuilder(
             data_column=data_column,
             label_columns=label_columns
         )
-        actual = builder.build(row, filename='', line_num=1)
+        return builder.build(row, filename='', line_num=1)
+
+    def test_can_load_default_column_names(self):
+        row = {'text': 'Text', 'label': 'Label'}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
+        actual = self.create_record(row, data_column, label_columns)
         expected = {'data': 'Text', 'label': [{'text': 'Label'}]}
         self.assert_record(actual, expected)
 
+    def test_can_specify_any_column_names(self):
+        row = {'body': 'Text', 'star': 5}
+        data_column = builders.DataColumn('body', TextData)
+        label_columns = [builders.LabelColumn('star', CategoryLabel)]
+        actual = self.create_record(row, data_column, label_columns)
+        expected = {'data': 'Text', 'label': [{'text': '5'}]}
+        self.assert_record(actual, expected)
+
     def test_can_load_only_text_column(self):
         row = {'text': 'Text', 'label': None}
         data_column = builders.DataColumn('text', TextData)
         label_columns = [builders.LabelColumn('label', CategoryLabel)]
-        builder = builders.ColumnBuilder(
-            data_column=data_column,
-            label_columns=label_columns
-        )
-        actual = builder.build(row, filename='', line_num=1)
+        actual = self.create_record(row, data_column, label_columns)
         expected = {'data': 'Text', 'label': []}
         self.assert_record(actual, expected)
+
+    def test_disallow_no_data_column(self):
+        row = {'label': 'Label'}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
+        with self.assertRaises(FileParseException):
+            self.create_record(row, data_column, label_columns)
+
+    def test_disallow_empty_text(self):
+        row = {'text': '', 'label': 'Label'}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
+        with self.assertRaises(FileParseException):
+            self.create_record(row, data_column, label_columns)
+
+    def test_can_load_int_as_text(self):
+        row = {'text': 5, 'label': 'Label'}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
+        actual = self.create_record(row, data_column, label_columns)
+        expected = {'data': '5', 'label': [{'text': 'Label'}]}
+        self.assert_record(actual, expected)
diff --git a/backend/api/views/upload/builders.py b/backend/api/views/upload/builders.py
index 7be2ec2f..83ed294a 100644
--- a/backend/api/views/upload/builders.py
+++ b/backend/api/views/upload/builders.py
@@ -1,4 +1,5 @@
 import abc
+from logging import getLogger
 from typing import Any, Dict, List, Optional, Type, TypeVar
 
 from pydantic import ValidationError
@@ -8,6 +9,7 @@ from .exception import FileParseException
 from .label import Label
 from .readers import Builder, Record
 
+logger = getLogger(__name__)
 T = TypeVar('T')
 
 
@@ -23,7 +25,7 @@ class PlainBuilder(Builder):
 
 def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> List[Label]:
     labels = row[name]
-    labels = [labels] if isinstance(labels, str) else labels
+    labels = [labels] if isinstance(labels, (str, int)) else labels
     return [label_class.parse(label) for label in labels]
 
 
@@ -40,7 +42,7 @@ class Column(abc.ABC):
 
     @abc.abstractmethod
     def __call__(self, row: Dict[Any, Any], filename: str):
-        raise NotImplementedError('')
+        raise NotImplementedError('Please implement this method in the subclass.')
 
 
 class DataColumn(Column):
@@ -77,7 +79,7 @@ class ColumnBuilder(Builder):
             try:
                 labels.extend(column(row, filename))
                 row.pop(column.name)
-            except (KeyError, ValidationError, TypeError):
-                pass
+            except (KeyError, ValidationError, TypeError) as e:
+                logger.error('Filename: %s, Line: %s, Parsed Data: %s, Error: %s', filename, line_num, row, e)
 
         return Record(data=data, label=labels, line_num=line_num, meta=row)
diff --git a/backend/api/views/upload/label.py b/backend/api/views/upload/label.py
index 887ce985..03a04865 100644
--- a/backend/api/views/upload/label.py
+++ b/backend/api/views/upload/label.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel, validator
 
@@ -57,7 +57,10 @@ class CategoryLabel(Label):
     def parse(cls, obj: Any):
         if isinstance(obj, str):
             return cls(label=obj)
-        raise TypeError(f'{obj} is not str.')
+        elif isinstance(obj, int):
+            return cls(label=str(obj))
+        else:
+            raise TypeError(f'{obj} is not str or int.')
 
     def create(self, project: Project) -> Optional[LabelModel]:
         return LabelModel(text=self.label, project=project, task_type='Category')
@@ -71,7 +74,7 @@ class CategoryLabel(Label):
 
 
 class OffsetLabel(Label):
-    label: str
+    label: Union[str, int]
     start_offset: int
     end_offset: int
 