Browse Source

Fix image import

pull/1665/head
Hironsan 2 years ago
parent
commit
a2c01cca2e
5 changed files with 24 additions and 7 deletions
  1. 4
      backend/data_import/celery_tasks.py
  2. 3
      backend/data_import/pipeline/builders.py
  3. 7
      backend/data_import/pipeline/factories.py
  4. 3
      backend/data_import/pipeline/parsers.py
  5. 14
      backend/data_import/tests/test_tasks.py

4
backend/data_import/celery_tasks.py

@ -4,7 +4,7 @@ from django.contrib.auth import get_user_model
from django.shortcuts import get_object_or_404
from api.models import Project
from .pipeline.factories import create_parser, create_bulder, create_cleaner
from .pipeline.factories import create_parser, create_builder, create_cleaner
from .pipeline.readers import Reader
from .pipeline.writers import BulkWriter
@ -15,7 +15,7 @@ def import_dataset(user_id, project_id, filenames, file_format: str, **kwargs):
user = get_object_or_404(get_user_model(), pk=user_id)
parser = create_parser(file_format, **kwargs)
builder = create_bulder(project, **kwargs)
builder = create_builder(project, **kwargs)
reader = Reader(filenames=filenames, parser=parser, builder=builder)
cleaner = create_cleaner(project)
writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE)

3
backend/data_import/pipeline/builders.py

@ -16,11 +16,12 @@ T = TypeVar('T')
class PlainBuilder(Builder):
def __init__(self, data_class: Type[BaseData]):
print(data_class)
self.data_class = data_class
def build(self, row: Dict[Any, Any], filename: str, line_num: int) -> Record:
data = self.data_class.parse(filename=filename)
yield Record(data=data)
return Record(data=data)
def build_label(row: Dict[Any, Any], name: str, label_class: Type[Label]) -> List[Label]:

7
backend/data_import/pipeline/factories.py

@ -55,12 +55,15 @@ def create_cleaner(project):
IMAGE_CLASSIFICATION: cleaners.CategoryCleaner
}
if project.project_type not in mapping:
ValueError(f'Invalid project type: {project.project_type}')
return cleaners.Cleaner(project)
cleaner_class = mapping.get(project.project_type, cleaners.Cleaner)
return cleaner_class(project)
def create_bulder(project, **kwargs):
def create_builder(project, **kwargs):
if not project.is_text_project:
return builders.PlainBuilder(data_class=get_data_class(project.project_type))
data_column = builders.DataColumn(
name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN,
value_class=get_data_class(project.project_type)

3
backend/data_import/pipeline/parsers.py

@ -96,6 +96,9 @@ class PlainParser(Parser):
This is for a task without any text.
"""
def __init__(self, **kwargs):
self.kwargs = kwargs
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
yield {}

14
backend/data_import/tests/test_tasks.py

@ -5,7 +5,7 @@ from django.test import TestCase
from data_import.celery_tasks import import_dataset
from api.models import (DOCUMENT_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING)
SEQUENCE_LABELING, IMAGE_CLASSIFICATION)
from examples.models import Example
from label_types.models import CategoryType, SpanType
from labels.models import Category, Span
@ -242,7 +242,7 @@ class TestImportSeq2seqData(TestImportData):
self.assert_examples(dataset)
class TextImportIntentDetectionAndSlotFillingData(TestImportData):
class TestImportIntentDetectionAndSlotFillingData(TestImportData):
task = INTENT_DETECTION_AND_SLOT_FILLING
def assert_examples(self, dataset):
@ -265,3 +265,13 @@ class TextImportIntentDetectionAndSlotFillingData(TestImportData):
]
self.import_dataset(filename, file_format)
self.assert_examples(dataset)
class TestImportImageClassificationData(TestImportData):
task = IMAGE_CLASSIFICATION
def test_example(self):
filename = 'images/1500x500.jpeg'
file_format = 'ImageFile'
self.import_dataset(filename, file_format)
self.assertEqual(Example.objects.count(), 1)
Loading…
Cancel
Save