From 09f7ec65bdb421a14c617b4db7d7aa23d4106c00 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Fri, 17 Dec 2021 14:35:38 +0900
Subject: [PATCH] Update ColumnBuilder

---
 backend/api/tests/upload/test_builder.py | 16 ++---
 backend/api/views/upload/builders.py     | 81 +++++++++++++++++-------
 backend/api/views/upload/data.py         | 13 ++--
 backend/api/views/upload/dataset.py      | 12 ++--
 backend/api/views/upload/readers.py      |  9 ++-
 5 files changed, 87 insertions(+), 44 deletions(-)

diff --git a/backend/api/tests/upload/test_builder.py b/backend/api/tests/upload/test_builder.py
index a8d50129..bd42f3b5 100644
--- a/backend/api/tests/upload/test_builder.py
+++ b/backend/api/tests/upload/test_builder.py
@@ -13,11 +13,11 @@ class TestColumnBuilder(unittest.TestCase):
 
     def test_can_load_default_column_names(self):
         row = {'text': 'Text', 'label': 'Label'}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
         builder = builders.ColumnBuilder(
-            data_class=TextData,
-            label_class=CategoryLabel,
-            text_column='text',
-            label_column='label'
+            data_column=data_column,
+            label_columns=label_columns
         )
         actual = builder.build(row, filename='', line_num=1)
         expected = {'data': 'Text', 'label': [{'text': 'Label'}]}
@@ -25,11 +25,11 @@ class TestColumnBuilder(unittest.TestCase):
 
     def test_can_load_only_text_column(self):
         row = {'text': 'Text', 'label': None}
+        data_column = builders.DataColumn('text', TextData)
+        label_columns = [builders.LabelColumn('label', CategoryLabel)]
         builder = builders.ColumnBuilder(
-            data_class=TextData,
-            label_class=CategoryLabel,
-            text_column='text',
-            label_column='label'
+            data_column=data_column,
+            label_columns=label_columns
         )
         actual = builder.build(row, filename='', line_num=1)
         expected = {'data': 'Text', 'label': []}
diff --git a/backend/api/views/upload/builders.py b/backend/api/views/upload/builders.py
index 669290b1..8eea86f1 100644
--- a/backend/api/views/upload/builders.py
+++ b/backend/api/views/upload/builders.py
@@ -1,11 +1,14 @@
-from typing import Any, Dict, Type
+import abc
+from typing import Any, Dict, List, Optional, Type, TypeVar
 
 from pydantic import ValidationError
 
 from .data import BaseData
 from .exception import FileParseException
 from .label import Label
-from .readers import DEFAULT_LABEL_COLUMN, DEFAULT_TEXT_COLUMN, Builder, Record
+from .readers import Builder, Record
+
+T = TypeVar('T')
 
 
 class PlainBuilder(Builder):
@@ -18,33 +21,63 @@ class PlainBuilder(Builder):
         yield Record(data=data)
 
 
-class ColumnBuilder(Builder):
+def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> List[Label]:
+    labels = row[name]
+    labels = [labels] if isinstance(labels, str) else labels
+    return [label_class.parse(label) for label in labels]
 
-    def __init__(self,
-                 data_class: Type[BaseData],
-                 label_class: Type[Label],
-                 text_column: str = DEFAULT_TEXT_COLUMN,
-                 label_column: str = DEFAULT_LABEL_COLUMN):
-        self.data_class = data_class
-        self.label_class = label_class
-        self.text_column = text_column
-        self.label_column = label_column
 
-    def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
-        if self.text_column not in row:
-            message = f'{self.text_column} does not exist.'
-            raise FileParseException(filename, line_num, message)
-        text = row.pop(self.text_column)
-        label = row.pop(self.label_column, [])
-        label = [label] if isinstance(label, str) else label
+def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filename: str) -> BaseData:
+    data = row[name]
+    return data_class.parse(text=data, filename=filename)
+
+
+class Column(abc.ABC):
+
+    def __init__(self, name: str, value_class: Type[T]):
+        self.name = name
+        self.value_class = value_class
+
+    @abc.abstractmethod
+    def __call__(self, row: Dict[Any, Any], filename: str):
+        raise NotImplementedError('')
+
+
+class DataColumn(Column):
+
+    def __call__(self, row: Dict[Any, Any], filename: str) -> BaseData:
+        return build_data(row, self.name, self.value_class, filename)
+
+
+class LabelColumn(Column):
+
+    def __call__(self, row: Dict[Any, Any], filename: str) -> List[Label]:
         try:
-            label = [self.label_class.parse(o) for o in label]
-        except (ValidationError, TypeError):
-            label = []
+            return build_label(row, self.name, self.value_class)
+        except (KeyError, ValidationError, TypeError):
+            return []
+
+
+class ColumnBuilder(Builder):
 
+    def __init__(self, data_column: Column, label_columns: Optional[List[Column]] = None):
+        self.data_column = data_column
+        self.label_columns = label_columns or []
+
+    def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
         try:
-            data = self.data_class.parse(text=text, filename=filename, meta=row)
-            return Record(data=data, label=label, line_num=line_num)
+            data = self.data_column(row, filename)
+            row.pop(self.data_column.name)
+        except KeyError:
+            message = f'{self.data_column.name} field does not exist.'
+            raise FileParseException(filename, line_num, message)
         except ValidationError:
             message = 'The empty text is not allowed.'
             raise FileParseException(filename, line_num, message)
+
+        labels = []
+        for column in self.label_columns:
+            labels.extend(column(row, filename))
+            row.pop(column.name, None)  # column may be absent; LabelColumn already yields [] for it
+
+        return Record(data=data, label=labels, line_num=line_num, meta=row)
diff --git a/backend/api/views/upload/data.py b/backend/api/views/upload/data.py
index ac98bc12..5a069950 100644
--- a/backend/api/views/upload/data.py
+++ b/backend/api/views/upload/data.py
@@ -1,6 +1,6 @@
 import abc
 import uuid
-from typing import Dict
+from typing import Any, Dict
 
 from pydantic import BaseModel, validator
 
@@ -9,7 +9,6 @@ from ...models import Example, Project
 
 class BaseData(BaseModel, abc.ABC):
     filename: str
-    meta: Dict = {}
 
     @classmethod
     def parse(cls, **kwargs):
@@ -19,7 +18,7 @@ class BaseData(BaseModel, abc.ABC):
         return hash(tuple(self.dict()))
 
     @abc.abstractmethod
-    def create(self, project: Project) -> Example:
+    def create(self, project: Project, meta: Dict[Any, Any]) -> Example:
         raise NotImplementedError('Please implement this method in the subclass.')
 
 
@@ -33,22 +32,22 @@ class TextData(BaseData):
         else:
             raise ValueError('is not empty.')
 
-    def create(self, project: Project) -> Example:
+    def create(self, project: Project, meta: Dict[Any, Any]) -> Example:
         return Example(
             uuid=uuid.uuid4(),
             project=project,
             filename=self.filename,
             text=self.text,
-            meta=self.meta
+            meta=meta
         )
 
 
 class FileData(BaseData):
 
-    def create(self, project: Project) -> Example:
+    def create(self, project: Project, meta: Dict[Any, Any]) -> Example:
         return Example(
             uuid=uuid.uuid4(),
             project=project,
             filename=self.filename,
-            meta=self.meta
+            meta=meta
         )
diff --git a/backend/api/views/upload/dataset.py b/backend/api/views/upload/dataset.py
index 7a4be84c..8b70e915 100644
--- a/backend/api/views/upload/dataset.py
+++ b/backend/api/views/upload/dataset.py
@@ -2,7 +2,7 @@ import csv
 import io
 import json
 import os
-from typing import Dict, Iterator, List, Optional, Type
+from typing import Any, Dict, Iterator, List, Optional, Type
 
 import chardet
 import pyexcel
@@ -22,11 +22,15 @@ class Record:
     def __init__(self,
                  data: Type[BaseData],
                  label: List[Label] = None,
+                 meta: Dict[Any, Any] = None,
                  line_num: int = -1):
         if label is None:
             label = []
+        if meta is None:
+            meta = {}
         self._data = data
         self._label = label
+        self._meta = meta
         self._line_num = line_num
 
     def __str__(self):
@@ -48,7 +52,7 @@ class Record:
         return self._data
 
     def create_data(self, project):
-        return self._data.create(project)
+        return self._data.create(project, self._meta)
 
     def create_label(self, project):
         return [label.create(project) for label in self._label]
@@ -139,12 +143,12 @@ class Dataset:
             label = []
 
         try:
-            data = self.data_class.parse(text=text, filename=filename, meta=row)
+            data = self.data_class.parse(text=text, filename=filename)
         except ValidationError:
             message = 'The empty text is not allowed.'
             raise FileParseException(filename, line_num, message)
 
-        record = Record(data=data, label=label, line_num=line_num)
+        record = Record(data=data, label=label, line_num=line_num, meta=row)
         return record
 
 
diff --git a/backend/api/views/upload/readers.py b/backend/api/views/upload/readers.py
index 058f37ec..45b87f0a 100644
--- a/backend/api/views/upload/readers.py
+++ b/backend/api/views/upload/readers.py
@@ -11,11 +11,18 @@ DEFAULT_LABEL_COLUMN = 'labels'
 
 class Record:
 
-    def __init__(self, data: Type[BaseData], label: List[Label] = None, line_num: int = -1):
+    def __init__(self,
+                 data: Type[BaseData],
+                 label: List[Label] = None,
+                 meta: Dict[Any, Any] = None,
+                 line_num: int = -1):
         if label is None:
             label = []
+        if meta is None:
+            meta = {}
         self._data = data
         self._label = label
+        self._meta = meta
         self._line_num = line_num
 
     def __str__(self):