Browse Source

Update ingest_task function

pull/1619/head
Hironsan 2 years ago
parent
commit
d4e216b188
8 changed files with 131 additions and 92 deletions
  1. 17
      backend/api/tasks.py
  2. 2
      backend/api/tests/upload/test_builder.py
  3. 6
      backend/api/tests/upload/test_parser.py
  4. 12
      backend/api/views/upload/builders.py
  5. 32
      backend/api/views/upload/factory.py
  6. 73
      backend/api/views/upload/parsers.py
  7. 42
      backend/api/views/upload/readers.py
  8. 39
      backend/api/views/upload/writers.py

17
backend/api/tasks.py

@ -7,8 +7,8 @@ from django.shortcuts import get_object_or_404
from .models import Project
from .views.download.factory import create_repository, create_writer
from .views.download.service import ExportApplicationService
from .views.upload.factory import (create_cleaner, get_data_class,
get_dataset_class, get_label_class)
from .views.upload.factory import create_bulder, create_cleaner, create_parser
from .views.upload.readers import Reader
from .views.upload.writers import BulkWriter
logger = get_task_logger(__name__)
@ -19,17 +19,12 @@ def ingest_data(user_id, project_id, filenames, format: str, **kwargs):
project = get_object_or_404(Project, pk=project_id)
user = get_object_or_404(get_user_model(), pk=user_id)
dataset_class = get_dataset_class(format)
dataset = dataset_class(
filenames=filenames,
label_class=get_label_class(project.project_type),
data_class=get_data_class(project.project_type),
**kwargs
)
it = iter(dataset)
parser = create_parser(format, **kwargs)
builder = create_bulder(project, **kwargs)
reader = Reader(filenames=filenames, parser=parser, builder=builder)
cleaner = create_cleaner(project)
writer = BulkWriter(batch_size=settings.IMPORT_BATCH_SIZE)
writer.save(it, project, user, cleaner)
writer.save(reader, project, user, cleaner)
return {'error': writer.errors}

2
backend/api/tests/upload/test_builder.py

@ -8,7 +8,7 @@ from ...views.upload.label import CategoryLabel
class TestColumnBuilder(unittest.TestCase):
def assert_record(self, actual, expected):
self.assertEqual(actual.data['text'], expected['data'])
self.assertEqual(actual.data.text, expected['data'])
self.assertEqual(actual.label, expected['label'])
def test_can_load_default_column_names(self):

6
backend/api/tests/upload/test_parser.py

@ -105,7 +105,7 @@ class TestFastTextParser(TestParser):
def test_read(self):
content = '__label__sauce __label__cheese Text'
parser = parsers.FastTextParser()
expected = [{'text': 'Text', 'labels': ['sauce', 'cheese']}]
expected = [{'text': 'Text', 'label': ['sauce', 'cheese']}]
self.assert_record(content, parser, expected)
@ -130,11 +130,11 @@ Blackburn\tI-PER
expected = [
{
'text': 'EU rejects German call to boycott British lamb .',
'labels': [(0, 2, 'ORG'), (11, 17, 'MISC'), (34, 41, 'MISC')]
'label': [(0, 2, 'ORG'), (11, 17, 'MISC'), (34, 41, 'MISC')]
},
{
'text': 'Peter Blackburn',
'labels': [(0, 15, 'PER')]
'label': [(0, 15, 'PER')]
}
]
self.assert_record(content, parser, expected)

12
backend/api/views/upload/builders.py

@ -52,10 +52,7 @@ class DataColumn(Column):
class LabelColumn(Column):
def __call__(self, row: Dict[Any, Any], filename: str) -> List[Label]:
try:
return build_label(row, self.name, self.value_class)
except (KeyError, ValidationError, TypeError):
return []
return build_label(row, self.name, self.value_class)
class ColumnBuilder(Builder):
@ -77,7 +74,10 @@ class ColumnBuilder(Builder):
labels = []
for column in self.label_columns:
labels.extend(column(row, filename))
row.pop(column.name)
try:
labels.extend(column(row, filename))
row.pop(column.name)
except (KeyError, ValidationError, TypeError):
pass
return Record(data=data, label=labels, line_num=line_num, meta=row)

32
backend/api/views/upload/factory.py

@ -1,6 +1,6 @@
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
from . import builders, catalog, cleaners, data, dataset, label, parsers
from . import builders, catalog, cleaners, data, label, parsers, readers
def get_data_class(project_type: str):
@ -11,25 +11,7 @@ def get_data_class(project_type: str):
return data.FileData
def get_dataset_class(format: str):
mapping = {
catalog.TextFile.name: dataset.TextFileDataset,
catalog.TextLine.name: dataset.TextLineDataset,
catalog.CSV.name: dataset.CsvDataset,
catalog.JSONL.name: dataset.JSONLDataset,
catalog.JSON.name: dataset.JSONDataset,
catalog.FastText.name: dataset.FastTextDataset,
catalog.Excel.name: dataset.ExcelDataset,
catalog.CoNLL.name: dataset.CoNLLDataset,
catalog.ImageFile.name: dataset.FileBaseDataset,
catalog.AudioFile.name: dataset.FileBaseDataset,
}
if format not in mapping:
ValueError(f'Invalid format: {format}')
return mapping[format]
def get_parser(file_format: str):
def create_parser(file_format: str, **kwargs):
mapping = {
catalog.TextFile.name: parsers.TextFileParser,
catalog.TextLine.name: parsers.LineParser,
@ -44,7 +26,7 @@ def get_parser(file_format: str):
}
if file_format not in mapping:
raise ValueError(f'Invalid format: {file_format}')
return mapping[file_format]
return mapping[file_format](**kwargs)
def get_label_class(project_type: str):
@ -74,14 +56,14 @@ def create_cleaner(project):
def create_bulder(project, **kwargs):
data_column = builders.DataColumn(
name=kwargs.get('column_data', 'text'),
name=kwargs.get('column_data', readers.DEFAULT_TEXT_COLUMN),
value_class=get_data_class(project.project_type)
)
# Todo: If project is EntityClassification,
# column names are fixed: entities, cats
label_column = builders.DataColumn(
name=kwargs.get('column_label', 'label'),
value_class=get_data_class(project.project_type)
label_column = builders.LabelColumn(
name=kwargs.get('column_label', readers.DEFAULT_LABEL_COLUMN),
value_class=get_label_class(project.project_type)
)
builder = builders.ColumnBuilder(
data_column=data_column,

73
backend/api/views/upload/parsers.py

@ -79,7 +79,7 @@ class PlainParser(Parser):
class LineParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING):
def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
self.encoding = encoding
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
@ -90,7 +90,7 @@ class LineParser(Parser):
class TextFileParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING):
def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
self.encoding = encoding
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
@ -101,7 +101,7 @@ class TextFileParser(Parser):
class CSVParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ','):
def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ',', **kwargs):
self.encoding = encoding
self.delimiter = delimiter
@ -115,39 +115,68 @@ class CSVParser(Parser):
class JSONParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING):
def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
self.encoding = encoding
self._errors = []
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
encoding = decide_encoding(filename, self.encoding)
with open(filename, encoding=encoding) as f:
rows = json.load(f)
for line_num, row in enumerate(rows, start=1):
yield row
try:
rows = json.load(f)
for line_num, row in enumerate(rows, start=1):
yield row
except json.decoder.JSONDecodeError as e:
error = FileParseException(filename, line_num=1, message=str(e))
self._errors.append(error)
@property
def errors(self) -> List[FileParseException]:
return self._errors
class JSONLParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING):
def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs):
self.encoding = encoding
self._errors = []
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
reader = LineReader(filename, self.encoding)
for line in reader:
yield json.loads(line)
for line_num, line in enumerate(reader, start=1):
try:
yield json.loads(line)
except json.decoder.JSONDecodeError as e:
error = FileParseException(filename, line_num, str(e))
self._errors.append(error)
@property
def errors(self) -> List[FileParseException]:
return self._errors
class ExcelParser(Parser):
def __init__(self, **kwargs):
self._errors = []
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
rows = pyexcel.iget_records(file_name=filename)
for row in rows:
yield row
try:
for line_num, row in enumerate(rows, start=1):
yield row
except pyexcel.exceptions.FileTypeNotSupported as e:
error = FileParseException(filename, line_num=1, message=str(e))
self._errors.append(error)
@property
def errors(self) -> List[FileParseException]:
return self._errors
class FastTextParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING, label: str = '__label__'):
def __init__(self, encoding: str = DEFAULT_ENCODING, label: str = '__label__', **kwargs):
self.encoding = encoding
self.label = label
@ -168,7 +197,7 @@ class FastTextParser(Parser):
class CoNLLParser(Parser):
def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ' ', scheme: str = 'IOB2'):
def __init__(self, encoding: str = DEFAULT_ENCODING, delimiter: str = ' ', scheme: str = 'IOB2', **kwargs):
self.encoding = encoding
self.delimiter = delimiter
mapping = {
@ -177,12 +206,23 @@ class CoNLLParser(Parser):
'IOBES': IOBES,
'BILOU': BILOU
}
self._errors = []
if scheme in mapping:
self.scheme = mapping[scheme]
else:
raise Exception('The scheme is not supported.')
self.scheme = None
@property
def errors(self) -> List[FileParseException]:
return self._errors
def parse(self, filename: str) -> Iterator[Dict[Any, Any]]:
if not self.scheme:
message = 'The specified scheme is not supported.'
error = FileParseException(filename, line_num=1, message=message)
self._errors.append(error)
return
reader = LineReader(filename, self.encoding)
words, tags = [], []
for line_num, line in enumerate(reader, start=1):
@ -191,7 +231,8 @@ class CoNLLParser(Parser):
tokens = line.split('\t')
if len(tokens) != 2:
message = 'A line must be separated by tab and has two columns.'
raise FileParseException(filename, line_num, message)
self._errors.append(FileParseException(filename, line_num, message))
return
word, tag = tokens
words.append(word)
tags.append(tag)

42
backend/api/views/upload/readers.py

@ -2,11 +2,13 @@ import abc
import collections.abc
from typing import Any, Dict, Iterator, List, Type
from .cleaners import Cleaner
from .data import BaseData
from .exception import FileParseException
from .label import Label
DEFAULT_TEXT_COLUMN = 'text'
DEFAULT_LABEL_COLUMN = 'labels'
DEFAULT_LABEL_COLUMN = 'label'
class Record:
@ -28,9 +30,29 @@ class Record:
def __str__(self):
return f'{self._data}\t{self._label}'
def clean(self, cleaner: Cleaner):
label = cleaner.clean(self._label)
changed = len(label) != len(self.label)
self._label = label
if changed:
raise FileParseException(
filename=self._data.filename,
line_num=self._line_num,
message=cleaner.message
)
@property
def data(self):
return self._data.dict()
return self._data
def create_data(self, project):
return self._data.create(project, self._meta)
def create_label(self, project):
return [label.create(project) for label in self._label]
def create_annotation(self, user, example, mapping):
return [label.create_annotation(user, example, mapping) for label in self._label]
@property
def label(self):
@ -66,6 +88,11 @@ class Parser(abc.ABC):
"""Parses the file and returns the dictionary."""
raise NotImplementedError('Please implement this method in the subclass.')
@property
def errors(self) -> List[FileParseException]:
"""Returns parsing errors."""
return []
class Builder(abc.ABC):
@ -87,10 +114,13 @@ class Reader(BaseReader):
for filename in self.filenames:
rows = self.parser.parse(filename)
for line_num, row in enumerate(rows, start=1):
record = self.builder.build(row, filename, line_num)
yield record
try:
yield self.builder.build(row, filename, line_num)
except FileParseException as e:
self._errors.append(e)
@property
def errors(self):
def errors(self) -> List[FileParseException]:
"""Aggregates parser and builder errors."""
return self._errors
errors = self.parser.errors + self._errors
return errors

39
backend/api/views/upload/writers.py

@ -1,34 +1,24 @@
import abc
import itertools
from collections import defaultdict
from typing import List
from django.conf import settings
from ...models import Example, Label, Project
from .exception import FileParseException, FileParseExceptions
from .exception import FileParseException
from .readers import BaseReader
class Writer(abc.ABC):
@abc.abstractmethod
def save(self, reader: BaseReader):
def save(self, reader: BaseReader, project: Project, user, cleaner):
"""Save the read contents to DB."""
raise NotImplementedError('Please implement this method in the subclass.')
class BulkWriterOld(Writer):
def __init__(self, batch_size: int):
self.batch_size = batch_size
def save(self, reader: BaseReader):
"""Bulk save the read contents."""
pass
def group_by_class(instances):
from collections import defaultdict
groups = defaultdict(list)
for instance in instances:
groups[instance.__class__].append(instance)
@ -79,28 +69,23 @@ class Examples:
klass.objects.bulk_create(instances)
class BulkWriter:
class BulkWriter(Writer):
def __init__(self, batch_size):
self.examples = Examples(batch_size)
self.errors = []
self._errors = []
def save(self, dataset, project, user, cleaner):
def save(self, reader: BaseReader, project, user, cleaner):
it = iter(reader)
while True:
try:
example = next(dataset)
example = next(it)
except StopIteration:
break
except FileParseException as err:
self.errors.append(err.dict())
continue
except FileParseExceptions as err:
self.errors.append(list(err))
continue
try:
example.clean(cleaner)
except FileParseException as err:
self.errors.append(err.dict())
self._errors.append(err)
self.examples.add(example)
if self.examples.is_full():
@ -109,6 +94,12 @@ class BulkWriter:
if not self.examples.is_empty():
self.create(project, user)
self.examples.clear()
self._errors.extend(reader.errors)
@property
def errors(self) -> List[FileParseException]:
self._errors.sort(key=lambda e: e.line_num)
return self._errors
def create(self, project, user):
self.examples.save_label(project)

Loading…
Cancel
Save