|
|
@ -4,7 +4,7 @@ from typing import List, Type |
|
|
|
from django.contrib.auth.models import User |
|
|
|
|
|
|
|
from .models import DummyLabelType |
|
|
|
from .pipeline.catalog import RELATION_EXTRACTION |
|
|
|
from .pipeline.catalog import RELATION_EXTRACTION, TextFile, TextLine |
|
|
|
from .pipeline.data import BaseData, BinaryData, TextData |
|
|
|
from .pipeline.exceptions import FileParseException |
|
|
|
from .pipeline.factories import create_parser |
|
|
@ -45,6 +45,21 @@ class Dataset(abc.ABC): |
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
|
|
class PlainDataset(Dataset):
    """Dataset that stores examples only, built from the reader's records.

    Records are turned into examples by an ``ExampleMaker`` configured with
    ``TextData`` and persisted through ``Example.objects.bulk_create``.
    """

    def __init__(self, reader: Reader, project: Project, **kwargs):
        """Initialise the base dataset and create the example maker."""
        super().__init__(reader, project, **kwargs)
        maker = ExampleMaker(project=project, data_class=TextData)
        self.example_maker = maker

    def save(self, user: User, batch_size: int = 1000):
        """Persist the reader's records as examples, one batch at a time.

        Args:
            user: Accepted as part of the ``save`` signature; this
                implementation does not use it.
            batch_size: Number of records requested from the reader per batch.
        """
        for chunk in self.reader.batch(batch_size):
            built = self.example_maker.make(chunk)
            Example.objects.bulk_create(built)

    @property
    def errors(self) -> List[FileParseException]:
        """Errors from the reader followed by errors from the example maker."""
        reader_errors = self.reader.errors
        maker_errors = self.example_maker.errors
        return reader_errors + maker_errors
|
|
|
|
|
|
|
|
|
|
|
class DatasetWithSingleLabelType(Dataset): |
|
|
|
data_class: Type[BaseData] |
|
|
|
label_class: Type[Label] |
|
|
@ -195,7 +210,7 @@ class CategoryAndSpanDataset(Dataset): |
|
|
|
return self.reader.errors + self.example_maker.errors + self.category_maker.errors + self.span_maker.errors |
|
|
|
|
|
|
|
|
|
|
|
def select_dataset(task: str, project: Project) -> Type[Dataset]: |
|
|
|
def select_dataset(project: Project, task: str, file_format: str) -> Type[Dataset]: |
|
|
|
mapping = { |
|
|
|
DOCUMENT_CLASSIFICATION: TextClassificationDataset, |
|
|
|
SEQUENCE_LABELING: SequenceLabelingDataset, |
|
|
@ -207,11 +222,13 @@ def select_dataset(task: str, project: Project) -> Type[Dataset]: |
|
|
|
} |
|
|
|
if task not in mapping: |
|
|
|
task = project.project_type |
|
|
|
if project.is_text_project and file_format in [TextLine.name, TextFile.name]: |
|
|
|
return PlainDataset |
|
|
|
return mapping[task] |
|
|
|
|
|
|
|
|
|
|
|
def load_dataset(task: str, file_format: str, data_files: List[FileName], project: Project, **kwargs) -> Dataset:
    """Build the dataset object used to import ``data_files`` into ``project``.

    Args:
        task: Task identifier passed to ``select_dataset``.
        file_format: Format name used both to create the parser and to
            choose the dataset class.
        data_files: Files handed to the ``Reader``.
        project: Project the dataset is created for.
        **kwargs: Forwarded to ``create_parser`` and to the dataset
            constructor.

    Returns:
        An instance of the dataset class chosen by ``select_dataset``.
    """
    parser = create_parser(file_format, **kwargs)
    reader = Reader(data_files, parser)
    # select_dataset takes (project, task, file_format); the file format is
    # needed so that plain-text uploads can be routed to their own dataset.
    dataset_class = select_dataset(project, task, file_format)
    return dataset_class(reader, project, **kwargs)