From 2d0091e3a2bad5be7e59fe548c703553c10ad4f9 Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 8 Apr 2021 08:33:31 +0900 Subject: [PATCH] Add exception handling --- app/api/views/upload/dataset.py | 74 +++++++++++++++------------------ 1 file changed, 33 insertions(+), 41 deletions(-) diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py index f7d02b54..24f930bf 100644 --- a/app/api/views/upload/dataset.py +++ b/app/api/views/upload/dataset.py @@ -1,11 +1,11 @@ import csv import json -from itertools import chain from typing import Dict, Iterator, List, Optional, Type import pyexcel from .data import BaseData +from .exception import FileParseException from .label import Label from .labels import Labels @@ -25,32 +25,20 @@ class Record: @property def data(self): - return self._data - - @property - def annotation(self): - return Labels(self._label) - - @property - def label(self): - return [label.name for label in self._label if label.has_name() and label.name] - - -class Records: - - def __init__(self, records: List[Record]): - self.records = records - - def data(self): - return [r.data.dict() for r in self.records] + return self._data.dict() def annotation(self, mapping: Dict[str, int]): - return [r.annotation.replace_label(mapping).dict() for r in self.records] + labels = Labels(self._label) + labels = labels.replace_label(mapping) + return labels.dict() + @property def label(self): - labels = set(chain(*[r.label for r in self.records])) return [ - {'text': label} for label in labels + { + 'text': label.name + } for label in self._label + if label.has_name() and label.name ] @@ -72,15 +60,6 @@ class Dataset: for filename in self.filenames: yield from self.load(filename) - def batch(self, batch_size) -> Records: - records = [] - for record in self: - records.append(record) - if len(records) == batch_size: - yield Records(records) - records = [] - yield Records(records) - def load(self, filename: str) -> Iterator[Record]: """Loads a file content.""" with open(filename, encoding=self.encoding) as f: @@ -88,7 +67,11 @@ class Dataset: record = Record(data=data) yield record - def from_row(self, filename: str, row: Dict) -> Record: + def from_row(self, filename: str, row: Dict, line_num: int) -> Record: + column_data = self.kwargs.get('column_data', 'text') + if column_data not in row: + message = f'{column_data} does not exist.' + raise FileParseException(filename, line_num, message) text = row.pop(self.kwargs.get('column_data', 'text')) label = row.pop(self.kwargs.get('column_label', 'label'), []) label = [label] if isinstance(label, str) else label @@ -132,9 +115,15 @@ class CsvDataset(Dataset): delimiter = self.kwargs.get('delimiter', ',') reader = csv.reader(f, delimiter=delimiter) header = next(reader) - for row in reader: + + column_data = self.kwargs.get('column_data', 'text') + if column_data not in header: + message = f'{column_data} does not exist in the header: {header}' + raise FileParseException(filename, 1, message) + + for line_num, row in enumerate(reader, start=2): row = dict(zip(header, row)) - yield self.from_row(filename, row) + yield self.from_row(filename, row, line_num) class JSONDataset(Dataset): @@ -142,36 +131,39 @@ class JSONDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: dataset = json.load(f) - for row in dataset: - yield self.from_row(filename, row) + for line_num, row in enumerate(dataset, start=1): + yield self.from_row(filename, row, line_num) class JSONLDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: - for line in f: + for line_num, line in enumerate(f, start=1): row = json.loads(line) - yield self.from_row(filename, row) + yield self.from_row(filename, row, line_num) class ExcelDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: records = pyexcel.iget_records(filename) - for row in records: - yield self.from_row(filename, row) + for line_num, row in enumerate(records, start=1): + yield self.from_row(filename, row, line_num) class FastTextDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: with open(filename, encoding=self.encoding) as f: - for i, line in enumerate(f, start=1): + for line_num, line in enumerate(f, start=1): labels = [] tokens = [] for token in line.rstrip().split(' '): if token.startswith('__label__'): + if token == '__label__': + message = 'Label name is empty.' + raise FileParseException(filename, line_num, message) label_name = token[len('__label__'):] labels.append(self.label_class.parse(label_name)) else: