From 6bb657a6655c8c674dffe5c1761027cc65df73b9 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Wed, 18 May 2022 15:41:33 +0900 Subject: [PATCH] Change the default value of column data and column label --- backend/data_import/celery_tasks.py | 6 +++++- backend/data_import/datasets.py | 10 +++++----- backend/data_import/pipeline/exceptions.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/backend/data_import/celery_tasks.py b/backend/data_import/celery_tasks.py index d7585d45..2f9d08b0 100644 --- a/backend/data_import/celery_tasks.py +++ b/backend/data_import/celery_tasks.py @@ -10,7 +10,11 @@ from django_drf_filepond.models import TemporaryUpload from .datasets import load_dataset from .pipeline.catalog import AudioFile, ImageFile -from .pipeline.exceptions import FileTypeException, MaximumFileSizeException, FileImportException +from .pipeline.exceptions import ( + FileImportException, + FileTypeException, + MaximumFileSizeException, +) from .pipeline.readers import FileName from projects.models import Project diff --git a/backend/data_import/datasets.py b/backend/data_import/datasets.py index 140f5e04..d424bc14 100644 --- a/backend/data_import/datasets.py +++ b/backend/data_import/datasets.py @@ -56,11 +56,11 @@ class DatasetWithSingleLabelType(Dataset): self.example_maker = ExampleMaker( project=project, data_class=self.data_class, - column_data=kwargs.get("column_data", DEFAULT_TEXT_COLUMN), - exclude_columns=[kwargs.get("column_label", DEFAULT_LABEL_COLUMN)], + column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, + exclude_columns=[kwargs.get("column_label") or DEFAULT_LABEL_COLUMN], ) self.label_maker = LabelMaker( - column=kwargs.get("column_label", DEFAULT_LABEL_COLUMN), label_class=self.label_class + column=kwargs.get("column_label") or DEFAULT_LABEL_COLUMN, label_class=self.label_class ) def save(self, user: User, batch_size: int = 1000): @@ -126,7 +126,7 @@ class RelationExtractionDataset(Dataset): self.example_maker = ExampleMaker( project=project, data_class=TextData, - column_data=kwargs.get("column_data", DEFAULT_TEXT_COLUMN), + column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, exclude_columns=["entities", "relations"], ) self.span_maker = LabelMaker(column="entities", label_class=SpanLabel) @@ -164,7 +164,7 @@ class CategoryAndSpanDataset(Dataset): self.example_maker = ExampleMaker( project=project, data_class=TextData, - column_data=kwargs.get("column_data", DEFAULT_TEXT_COLUMN), + column_data=kwargs.get("column_data") or DEFAULT_TEXT_COLUMN, exclude_columns=["cats", "entities"], ) self.category_maker = LabelMaker(column="cats", label_class=CategoryLabel) diff --git a/backend/data_import/pipeline/exceptions.py b/backend/data_import/pipeline/exceptions.py index fad70067..efab9ca7 100644 --- a/backend/data_import/pipeline/exceptions.py +++ b/backend/data_import/pipeline/exceptions.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Any, Dict class FileImportException(Exception):