From 3b8941e7e6f98ec9c6760fb6546a33d3d1c8e45a Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Thu, 19 May 2022 08:17:47 +0900
Subject: [PATCH] Add plain dataset

---
 backend/data_import/datasets.py         | 23 ++++++++++++++++++++---
 backend/data_import/tests/test_tasks.py |  6 ++++--
 2 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/backend/data_import/datasets.py b/backend/data_import/datasets.py
index d8876a67..76e8b3e2 100644
--- a/backend/data_import/datasets.py
+++ b/backend/data_import/datasets.py
@@ -4,7 +4,7 @@ from typing import List, Type
 from django.contrib.auth.models import User
 
 from .models import DummyLabelType
-from .pipeline.catalog import RELATION_EXTRACTION
+from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
 from .pipeline.data import BaseData, BinaryData, TextData
 from .pipeline.exceptions import FileParseException
 from .pipeline.factories import create_parser
@@ -45,6 +45,21 @@ class Dataset(abc.ABC):
         raise NotImplementedError()
 
 
+class PlainDataset(Dataset):
+    def __init__(self, reader: Reader, project: Project, **kwargs):
+        super().__init__(reader, project, **kwargs)
+        self.example_maker = ExampleMaker(project=project, data_class=TextData)
+
+    def save(self, user: User, batch_size: int = 1000):
+        for records in self.reader.batch(batch_size):
+            examples = self.example_maker.make(records)
+            Example.objects.bulk_create(examples)
+
+    @property
+    def errors(self) -> List[FileParseException]:
+        return self.reader.errors + self.example_maker.errors
+
+
 class DatasetWithSingleLabelType(Dataset):
     data_class: Type[BaseData]
     label_class: Type[Label]
@@ -195,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
         return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
 
 
-def select_dataset(task: str, project: Project) -> Type[Dataset]:
+def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
     mapping = {
         DOCUMENT_CLASSIFICATION: TextClassificationDataset,
         SEQUENCE_LABELING: SequenceLabelingDataset,
@@ -207,11 +222,13 @@ def select_dataset(task: str, project: Project) -> Type[Dataset]:
     }
     if task not in mapping:
         task = project.project_type
+    if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
+        return PlainDataset
     return mapping[task]
 
 
 def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
     parser = create_parser(file_format, **kwargs)
     reader = Reader(data_files, parser)
-    dataset_class = select_dataset(task, project)
+    dataset_class = select_dataset(project, task, file_format)
     return dataset_class(reader, project, **kwargs)
diff --git a/backend/data_import/tests/test_tasks.py b/backend/data_import/tests/test_tasks.py
index 54c71434..888a0bd8 100644
--- a/backend/data_import/tests/test_tasks.py
+++ b/backend/data_import/tests/test_tasks.py
@@ -132,15 +132,17 @@ class TestImportClassificationData(TestImportData):
         filename = "example.txt"
         file_format = "TextFile"
         dataset = [("exampleA\nexampleB\n\nexampleC\n", [])]
-        self.import_dataset(filename, file_format, self.task)
+        response = self.import_dataset(filename, file_format, self.task)
         self.assert_examples(dataset)
+        self.assertEqual(len(response["error"]), 0)
 
     def test_textline(self):
         filename = "example.txt"
         file_format = "TextLine"
         dataset = [("exampleA", []), ("exampleB", []), ("exampleC", [])]
-        self.import_dataset(filename, file_format, self.task)
+        response = self.import_dataset(filename, file_format, self.task)
         self.assert_examples(dataset)
+        self.assertEqual(len(response["error"]), 1)
 
     def test_wrong_jsonl(self):
         filename = "text_classification/example.json"