Browse Source

Merge pull request #223 from CatalystCode/enhancement/data-import-coverage

Enhancement/Ensure data pagination is covered in tests
pull/239/head
Hiroki Nakayama 5 years ago
committed by GitHub
parent
commit
427f59b23d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 13 deletions
  1. 2
      app/app/settings.py
  2. 1
      app/server/tests/test_api.py
  3. 18
      app/server/utils.py

2
app/app/settings.py

@@ -257,7 +257,7 @@ SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https')
 # Size of the batch for creating documents
 # on the import phase
-IMPORT_BATCH_SIZE = 500
+IMPORT_BATCH_SIZE = env.int('IMPORT_BATCH_SIZE', 500)
 GOOGLE_TRACKING_ID = env('GOOGLE_TRACKING_ID', 'UA-125643874-2')

1
app/server/tests/test_api.py

@@ -904,6 +904,7 @@ class TestFeatures(APITestCase):
         self.assertFalse(response.json().get('cloud_upload'))
+@override_settings(IMPORT_BATCH_SIZE=2)
 class TestParser(APITestCase):
     def parser_helper(self, filename, parser, include_label=True):

18
app/server/utils.py

@@ -7,10 +7,10 @@ from collections import defaultdict
 from random import Random
 from django.db import transaction
+from django.conf import settings
 from rest_framework.renderers import JSONRenderer
 from seqeval.metrics.sequence_labeling import get_entities
-from app.settings import IMPORT_BATCH_SIZE
 from .exceptions import FileParseException
 from .models import Label
 from .serializers import DocumentSerializer, LabelSerializer
@@ -242,19 +242,13 @@ class CoNLLParser(FileParser):
     ```
     """
     def parse(self, file):
-        """Store json for seq2seq.
-        Return format:
-            {"text": "Python is awesome!", "labels": [[0, 6, "Product"],]}
-            ...
-        """
         words, tags = [], []
         data = []
+        file = io.TextIOWrapper(file, encoding='utf-8')
         for i, line in enumerate(file, start=1):
-            if len(data) >= IMPORT_BATCH_SIZE:
+            if len(data) >= settings.IMPORT_BATCH_SIZE:
                 yield data
                 data = []
-            line = line.decode('utf-8')
             line = line.strip()
             if line:
                 try:
@@ -301,7 +295,7 @@ class PlainTextParser(FileParser):
     def parse(self, file):
         file = io.TextIOWrapper(file, encoding='utf-8')
         while True:
-            batch = list(itertools.islice(file, IMPORT_BATCH_SIZE))
+            batch = list(itertools.islice(file, settings.IMPORT_BATCH_SIZE))
             if not batch:
                 break
             yield [{'text': line.strip()} for line in batch]
@@ -327,7 +321,7 @@ class CSVParser(FileParser):
         columns = next(reader)
         data = []
         for i, row in enumerate(reader, start=2):
-            if len(data) >= IMPORT_BATCH_SIZE:
+            if len(data) >= settings.IMPORT_BATCH_SIZE:
                 yield data
                 data = []
             if len(row) == len(columns) and len(row) >= 2:
@@ -347,7 +341,7 @@ class JSONParser(FileParser):
         file = io.TextIOWrapper(file, encoding='utf-8')
         data = []
         for i, line in enumerate(file, start=1):
-            if len(data) >= IMPORT_BATCH_SIZE:
+            if len(data) >= settings.IMPORT_BATCH_SIZE:
                 yield data
                 data = []
             try:

Loading…
Cancel
Save