Browse Source

Sort exported labels, fix #1466

pull/1558/head
Hironsan 3 years ago
parent
commit
cf9ee87a2a
2 changed files with 12 additions and 1 deletions
  1. 10
      backend/api/tests/download/test_writer.py
  2. 3
      backend/api/views/download/writer.py

10
backend/api/tests/download/test_writer.py

@ -32,6 +32,16 @@ class TestCSVWriter(unittest.TestCase):
} }
self.assertEqual(line, expected) self.assertEqual(line, expected)
def test_label_order(self):
writer = CsvWriter('.')
record1 = Record(id=0, data='', label=['labelA', 'labelB'], user='', metadata={})
record2 = Record(id=0, data='', label=['labelB', 'labelA'], user='', metadata={})
line1 = writer.create_line(record1)
line2 = writer.create_line(record2)
expected = 'labelA#labelB'
self.assertEqual(line1['label'], expected)
self.assertEqual(line2['label'], expected)
@patch('os.remove') @patch('os.remove')
@patch('zipfile.ZipFile') @patch('zipfile.ZipFile')
@patch('csv.DictWriter.writerow') @patch('csv.DictWriter.writerow')

3
backend/api/views/download/writer.py

@ -84,7 +84,7 @@ class CsvWriter(BaseWriter):
return { return {
'id': record.id, 'id': record.id,
'data': record.data, 'data': record.data,
'label': '#'.join(record.label),
'label': '#'.join(sorted(record.label)),
**record.metadata **record.metadata
} }
@ -144,6 +144,7 @@ class FastTextWriter(LineWriter):
def create_line(self, record): def create_line(self, record):
line = [f'__label__{label}' for label in record.label] line = [f'__label__{label}' for label in record.label]
line.sort()
line.append(record.data) line.append(record.data)
line = ' '.join(line) line = ' '.join(line)
return line return line
Loading…
Cancel
Save