diff --git a/app/api/views/upload/dataset.py b/app/api/views/upload/dataset.py index c14610df..fa0f7789 100644 --- a/app/api/views/upload/dataset.py +++ b/app/api/views/upload/dataset.py @@ -24,10 +24,12 @@ class Dataset: def __init__(self, filenames: List[str], + encoding: Optional[str] = None, column_data: str = 'text', column_label: str = 'label', **kwargs): self.filenames = filenames + self.encoding = encoding self.column_data = column_data self.column_label = column_label self.kwargs = kwargs @@ -38,7 +40,9 @@ class Dataset: def load(self, filename: str) -> Iterator[Record]: """Loads a file content.""" - raise NotImplementedError() + with open(filename, encoding=self.encoding) as f: + record = Record(filename=filename, data=f.read()) + yield record def from_row(self, filename: str, row: Dict) -> Record: data = row.pop(self.column_data) @@ -57,7 +61,7 @@ class FileBaseDataset(Dataset): class TextFileDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: record = Record(filename=filename, data=f.read()) yield record @@ -65,7 +69,7 @@ class TextFileDataset(Dataset): class TextLineDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: for line in f: record = Record(filename=filename, data=line.rstrip()) yield record @@ -74,7 +78,7 @@ class TextLineDataset(Dataset): class CsvDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: delimiter = self.kwargs.get('delimiter', ',') reader = csv.reader(f, delimiter=delimiter) header = next(reader) @@ -86,7 +90,7 @@ class CsvDataset(Dataset): class JSONDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: dataset = json.load(f) for row in dataset: yield self.from_row(filename, row) @@ -95,7 +99,7 @@ class JSONDataset(Dataset): class JSONLDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: for line in f: row = json.loads(line) yield self.from_row(filename, row) @@ -112,7 +116,7 @@ class ExcelDataset(Dataset): class FastTextDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: for i, line in enumerate(f, start=1): labels = [] tokens = [] @@ -129,5 +133,5 @@ class FastTextDataset(Dataset): class ConllDataset(Dataset): def load(self, filename: str) -> Iterator[Record]: - with open(filename) as f: + with open(filename, encoding=self.encoding) as f: pass