Browse Source

Enable to store upload file name

pull/1779/head
Hironsan 2 years ago
parent
commit
921e43c00a
5 changed files with 36 additions and 24 deletions
  1. 2
      backend/data_import/celery_tasks.py
  2. 22
      backend/data_import/pipeline/builders.py
  3. 14
      backend/data_import/pipeline/data.py
  4. 18
      backend/data_import/pipeline/readers.py
  5. 4
      backend/data_import/tests/test_builder.py

2
backend/data_import/celery_tasks.py

@ -55,7 +55,7 @@ def import_dataset(user_id, project_id, file_format: str, upload_ids: List[str],
upload_ids, errors = check_uploaded_files(upload_ids, file_format)
temporary_uploads = TemporaryUpload.objects.filter(upload_id__in=upload_ids)
filenames = [
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, original_name=tu.upload_name)
FileName(full_path=tu.get_file_path(), generated_name=tu.file.name, upload_name=tu.upload_name)
for tu in temporary_uploads
]

22
backend/data_import/pipeline/builders.py

@ -7,7 +7,7 @@ from pydantic import ValidationError
from .data import BaseData
from .exceptions import FileParseException
from .labels import Label
from .readers import Builder, Record
from .readers import Builder, FileName, Record
logger = getLogger(__name__)
@ -16,8 +16,8 @@ class PlainBuilder(Builder):
def __init__(self, data_class: Type[BaseData]):
self.data_class = data_class
def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
data = self.data_class.parse(filename=filename)
def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record:
data = self.data_class.parse(filename=filename.generated_name, upload_name=filename.upload_name)
return Record(data=data)
@ -27,9 +27,9 @@ def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> Lis
return [label_class.parse(label) for label in labels]
def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filename: str) -> BaseData:
def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filename: FileName) -> BaseData:
data = row[name]
return data_class.parse(text=data, filename=filename)
return data_class.parse(text=data, filename=filename.generated_name, upload_name=filename.upload_name)
class Column(abc.ABC):
@ -39,17 +39,17 @@ class Column(abc.ABC):
self.value_class = value_class
@abc.abstractmethod
def __call__(self, row: Dict[Any, Any], filename: str):
def __call__(self, row: Dict[Any, Any], filename: FileName):
raise NotImplementedError("Please implement this method in the subclass.")
class DataColumn(Column):
def __call__(self, row: Dict[Any, Any], filename: str) -> BaseData:
def __call__(self, row: Dict[Any, Any], filename: FileName) -> BaseData:
return build_data(row, self.name, self.value_class, filename)
class LabelColumn(Column):
def __call__(self, row: Dict[Any, Any], filename: str) -> List[Label]:
def __call__(self, row: Dict[Any, Any], filename: FileName) -> List[Label]:
return build_label(row, self.name, self.value_class)
@ -58,16 +58,16 @@ class ColumnBuilder(Builder):
self.data_column = data_column
self.label_columns = label_columns or []
def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record:
try:
data = self.data_column(row, filename)
row.pop(self.data_column.name)
except KeyError:
message = f"{self.data_column.name} field does not exist."
raise FileParseException(filename, line_num, message)
raise FileParseException(filename.upload_name, line_num, message)
except ValidationError:
message = "The empty text is not allowed."
raise FileParseException(filename, line_num, message)
raise FileParseException(filename.upload_name, line_num, message)
labels = []
for column in self.label_columns:

14
backend/data_import/pipeline/data.py

@ -10,6 +10,7 @@ from projects.models import Project
class BaseData(BaseModel, abc.ABC):
filename: str
upload_name: str
@classmethod
def parse(cls, **kwargs):
@ -34,9 +35,18 @@ class TextData(BaseData):
raise ValueError("is not empty.")
def create(self, project: Project, meta: Dict[Any, Any]) -> Example:
return Example(uuid=uuid.uuid4(), project=project, filename=self.filename, text=self.text, meta=meta)
return Example(
uuid=uuid.uuid4(),
project=project,
filename=self.filename,
upload_name=self.upload_name,
text=self.text,
meta=meta,
)
class FileData(BaseData):
def create(self, project: Project, meta: Dict[Any, Any]) -> Example:
return Example(uuid=uuid.uuid4(), project=project, filename=self.filename, meta=meta)
return Example(
uuid=uuid.uuid4(), project=project, filename=self.filename, upload_name=self.upload_name, meta=meta
)

18
backend/data_import/pipeline/readers.py

@ -87,22 +87,22 @@ class Parser(abc.ABC):
return []
@dataclasses.dataclass
class FileName:
full_path: str
generated_name: str
upload_name: str
class Builder(abc.ABC):
"""The abstract Record builder."""
@abc.abstractmethod
def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
def build(self, row: Dict[Any, Any], filename: FileName, line_num: int) -> Record:
"""Builds the record from the dictionary."""
raise NotImplementedError("Please implement this method in the subclass.")
@dataclasses.dataclass
class FileName:
full_path: str
generated_name: str
original_name: str
class Reader(BaseReader):
def __init__(self, filenames: List[FileName], parser: Parser, builder: Builder):
self.filenames = filenames
@ -115,7 +115,7 @@ class Reader(BaseReader):
rows = self.parser.parse(filename.full_path)
for line_num, row in enumerate(rows, start=1):
try:
yield self.builder.build(row, filename.generated_name, line_num)
yield self.builder.build(row, filename, line_num)
except FileParseException as e:
self._errors.append(e)

4
backend/data_import/tests/test_builder.py

@ -5,6 +5,7 @@ from data_import.pipeline import builders
from data_import.pipeline.data import TextData
from data_import.pipeline.exceptions import FileParseException
from data_import.pipeline.labels import CategoryLabel, SpanLabel
from data_import.pipeline.readers import FileName
class TestColumnBuilder(unittest.TestCase):
@ -14,7 +15,8 @@ class TestColumnBuilder(unittest.TestCase):
def create_record(self, row, data_column: builders.DataColumn, label_columns: Optional[List[builders.Column]]):
builder = builders.ColumnBuilder(data_column=data_column, label_columns=label_columns)
return builder.build(row, filename="", line_num=1)
filename = FileName("", "", "")
return builder.build(row, filename=filename, line_num=1)
def test_can_load_default_column_names(self):
row = {"text": "Text", "label": "Label"}

Loading…
Cancel
Save