
Add plain dataset

pull/1823/head
Hironsan committed 3 years ago
commit 3b8941e7e6
2 changed files with 24 additions and 5 deletions
  1. backend/data_import/datasets.py (23 changed lines)
  2. backend/data_import/tests/test_tasks.py (6 changed lines)

backend/data_import/datasets.py (23 changed lines)

@@ -4,7 +4,7 @@ from typing import List, Type
 from django.contrib.auth.models import User
 
 from .models import DummyLabelType
-from .pipeline.catalog import RELATION_EXTRACTION
+from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine
 from .pipeline.data import BaseData, BinaryData, TextData
 from .pipeline.exceptions import FileParseException
 from .pipeline.factories import create_parser
@@ -45,6 +45,21 @@ class Dataset(abc.ABC):
         raise NotImplementedError()
 
 
+class PlainDataset(Dataset):
+    def __init__(self, reader: Reader, project: Project, **kwargs):
+        super().__init__(reader, project, **kwargs)
+        self.example_maker = ExampleMaker(project=project, data_class=TextData)
+
+    def save(self, user: User, batch_size: int = 1000):
+        for records in self.reader.batch(batch_size):
+            examples = self.example_maker.make(records)
+            Example.objects.bulk_create(examples)
+
+    @property
+    def errors(self) -> List[FileParseException]:
+        return self.reader.errors + self.example_maker.errors
+
+
 class DatasetWithSingleLabelType(Dataset):
     data_class: Type[BaseData]
     label_class: Type[Label]
@@ -195,7 +210,7 @@ class CategoryAndSpanDataset(Dataset):
         return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors
 
 
-def select_dataset(task: str, project: Project) -> Type[Dataset]:
+def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]:
     mapping = {
         DOCUMENT_CLASSIFICATION: TextClassificationDataset,
         SEQUENCE_LABELING: SequenceLabelingDataset,
@@ -207,11 +222,13 @@ def select_dataset(task: str, project: Project) -> Type[Dataset]:
     }
     if task not in mapping:
         task = project.project_type
+    if project.is_text_project and file_format in [TextLine.name, TextFile.name]:
+        return PlainDataset
     return mapping[task]
 
 
 def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
     parser = create_parser(file_format, **kwargs)
     reader = Reader(data_files, parser)
-    dataset_class = select_dataset(task, project)
+    dataset_class = select_dataset(project, task, file_format)
    return dataset_class(reader, project, **kwargs)
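
As a reviewer aid, here is a minimal standalone sketch of the dispatch this change introduces: plain-text formats (TextLine, TextFile) on text projects bypass the per-task mapping and land in PlainDataset, which stores unlabeled examples. FakeProject and the two stand-in dataset classes below are illustrative assumptions, not the real doccano models or pipeline classes.

# Standalone sketch of the selection logic added in this commit.
# FakeProject and the stand-in classes are assumptions for illustration.
from dataclasses import dataclass
from typing import List


@dataclass
class FakeProject:
    project_type: str
    is_text_project: bool


class PlainDataset:
    """Stand-in: stores raw text records as unlabeled examples."""

    def __init__(self, records: List[str]):
        self.examples = [{"text": r, "labels": []} for r in records]


class TextClassificationDataset:
    """Stand-in for the label-aware dataset classes."""


def select_dataset(project: FakeProject, task: str, file_format: str):
    mapping = {"DocumentClassification": TextClassificationDataset}
    if task not in mapping:
        task = project.project_type
    # New branch from this commit: plain text uploads skip label parsing.
    if project.is_text_project and file_format in ("TextLine", "TextFile"):
        return PlainDataset
    return mapping[task]


if __name__ == "__main__":
    project = FakeProject(project_type="DocumentClassification", is_text_project=True)
    print(select_dataset(project, "DocumentClassification", "TextLine").__name__)   # PlainDataset
    print(select_dataset(project, "DocumentClassification", "JSONL").__name__)      # TextClassificationDataset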

backend/data_import/tests/test_tasks.py (6 changed lines)

@@ -132,15 +132,17 @@ class TestImportClassificationData(TestImportData):
         filename = "example.txt"
         file_format = "TextFile"
         dataset = [("exampleA\nexampleB\n\nexampleC\n", [])]
-        self.import_dataset(filename, file_format, self.task)
+        response = self.import_dataset(filename, file_format, self.task)
         self.assert_examples(dataset)
+        self.assertEqual(len(response["error"]), 0)
 
     def test_textline(self):
         filename = "example.txt"
         file_format = "TextLine"
         dataset = [("exampleA", []), ("exampleB", []), ("exampleC", [])]
-        self.import_dataset(filename, file_format, self.task)
+        response = self.import_dataset(filename, file_format, self.task)
         self.assert_examples(dataset)
+        self.assertEqual(len(response["error"]), 1)
 
     def test_wrong_jsonl(self):
         filename = "text_classification/example.json"
