Browse Source

Add a feature to download intent detection and slot filling data

pull/1619/head
Hironsan 2 years ago
parent
commit
6b982a7495
9 changed files with 124 additions and 7 deletions
  1. 6
      backend/api/tests/api/utils.py
  2. 34
      backend/api/tests/download/test_repositories.py
  3. 26
      backend/api/tests/download/test_writer.py
  4. 11
      backend/api/views/download/catalog.py
  5. 4
      backend/api/views/download/data.py
  6. 5
      backend/api/views/download/examples.py
  7. 5
      backend/api/views/download/factory.py
  8. 27
      backend/api/views/download/repositories.py
  9. 13
      backend/api/views/download/writer.py

6
backend/api/tests/api/utils.py

@ -8,7 +8,8 @@ from model_mommy import mommy
from rest_framework import status
from rest_framework.test import APITestCase
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT, Role, RoleMapping)
DATA_DIR = os.path.join(os.path.dirname(__file__), '../data')
@ -63,7 +64,8 @@ def make_project(
SEQUENCE_LABELING: 'SequenceLabelingProject',
SEQ2SEQ: 'Seq2seqProject',
SPEECH2TEXT: 'Speech2TextProject',
IMAGE_CLASSIFICATION: 'ImageClassificationProject'
IMAGE_CLASSIFICATION: 'ImageClassificationProject',
INTENT_DETECTION_AND_SLOT_FILLING: 'IntentDetectionAndSlotFillingProject'
}.get(task, 'Project')
project = mommy.make(
_model=project_model,

34
backend/api/tests/download/test_repositories.py

@ -0,0 +1,34 @@
import unittest
from model_mommy import mommy
from ...models import INTENT_DETECTION_AND_SLOT_FILLING
from ...views.download.repositories import IntentDetectionSlotFillingRepository
from ..api.utils import prepare_project
class TestCSVWriter(unittest.TestCase):
def setUp(self):
self.project = prepare_project(INTENT_DETECTION_AND_SLOT_FILLING)
def test_list(self):
example = mommy.make('Example', project=self.project.item, text='example')
category = mommy.make('Category', example=example, user=self.project.users[0])
span = mommy.make('Span', example=example, user=self.project.users[0], start_offset=0, end_offset=1)
repository = IntentDetectionSlotFillingRepository(self.project.item)
expected = [
{
'data': example.text,
'label': {
'cats': [category.label.text],
'entities': [(span.start_offset, span.end_offset, span.label.text)]
}
}
]
records = list(repository.list())
self.assertEqual(len(records), len(expected))
for record, expect in zip(records, expected):
self.assertEqual(record.data, expect['data'])
self.assertEqual(record.label['cats'], expect['label']['cats'])
self.assertEqual(record.label['entities'], expect['label']['entities'])

26
backend/api/tests/download/test_writer.py

@ -1,8 +1,9 @@
import json
import unittest
from unittest.mock import call, patch
from ...views.download.data import Record
from ...views.download.writer import CsvWriter
from ...views.download.writer import CsvWriter, IntentAndSlotWriter
class TestCSVWriter(unittest.TestCase):
@ -61,3 +62,26 @@ class TestCSVWriter(unittest.TestCase):
call({'id': 2, 'data': 'exampleC', 'label': 'labelC', 'meta': 'secretC'})
]
csv_io.assert_has_calls(calls)
class TestIntentWriter(unittest.TestCase):
def setUp(self):
self.record = Record(
id=0,
data='exampleA',
label={'cats': ['positive'], 'entities': [(0, 1, 'LOC')]},
user='admin',
metadata={}
)
def test_create_line(self):
writer = IntentAndSlotWriter('.')
actual = writer.create_line(self.record)
expected = {
'id': self.record.id,
'text': self.record.data,
'cats': ['positive'],
'entities': [[0, 1, 'LOC']],
}
self.assertEqual(json.loads(actual), expected)

11
backend/api/views/download/catalog.py

@ -4,7 +4,8 @@ from typing import Dict, List, Type
from pydantic import BaseModel
from typing_extensions import Literal
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
from . import examples
@ -39,6 +40,11 @@ class JSONL(Format):
extension = 'jsonl'
class IntentAndSlot(Format):
name = 'JSONL(intent and slot)'
extension = 'jsonl'
class OptionDelimiter(BaseModel):
delimiter: Literal[',', '\t', ';', '|', ' '] = ','
@ -84,6 +90,9 @@ Options.register(SEQ2SEQ, CSV, OptionDelimiter, examples.Text_CSV)
Options.register(SEQ2SEQ, JSON, OptionNone, examples.Text_JSON)
Options.register(SEQ2SEQ, JSONL, OptionNone, examples.Text_JSONL)
# Intent detection and slot filling
Options.register(INTENT_DETECTION_AND_SLOT_FILLING, IntentAndSlot, OptionNone, examples.INTENT_JSONL)
# Image Classification
Options.register(IMAGE_CLASSIFICATION, JSONL, OptionNone, examples.CategoryImageClassification)

4
backend/api/views/download/data.py

@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List
from typing import Any, Dict, List, Union
class Record:
@ -7,7 +7,7 @@ class Record:
def __init__(self,
id: int,
data: str,
label: List[Any],
label: Union[List[Any], Dict[Any, Any]],
user: str,
metadata: Dict[Any, Any]):
self.id = id

5
backend/api/views/download/examples.py

@ -71,3 +71,8 @@ Speech2Text = """
}
]
"""
INTENT_JSONL = """
{"text": "Find a flight from Memphis to Tacoma", "entities": [[0, 26, "City"], [30, 36, "City"]], "cats": ["flight"]}
{"text": "I want to know what airports are in Los Angeles", "entities": [[36, 47, "City"]], "cats": ["airport"]}
"""

5
backend/api/views/download/factory.py

@ -1,6 +1,7 @@
from typing import Type
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT)
from . import catalog, repositories, writer
@ -12,6 +13,7 @@ def create_repository(project) -> repositories.BaseRepository:
SEQ2SEQ: repositories.Seq2seqRepository,
IMAGE_CLASSIFICATION: repositories.FileRepository,
SPEECH2TEXT: repositories.Speech2TextRepository,
INTENT_DETECTION_AND_SLOT_FILLING: repositories.IntentDetectionSlotFillingRepository,
}
if project.project_type not in mapping:
ValueError(f'Invalid project type: {project.project_type}')
@ -25,6 +27,7 @@ def create_writer(format: str) -> Type[writer.BaseWriter]:
catalog.JSON.name: writer.JSONWriter,
catalog.JSONL.name: writer.JSONLWriter,
catalog.FastText.name: writer.FastTextWriter,
catalog.IntentAndSlot.name: writer.IntentAndSlotWriter
}
if format not in mapping:
ValueError(f'Invalid format: {format}')

27
backend/api/views/download/repositories.py

@ -162,3 +162,30 @@ class Seq2seqRepository(TextRepository):
for a in doc.texts.all():
label_per_user[a.user.username].append(a.text)
return label_per_user
class IntentDetectionSlotFillingRepository(TextRepository):
@property
def docs(self):
return Example.objects.filter(project=self.project).prefetch_related(
'categories__user',
'categories__label',
'spans__user',
'spans__label'
)
def label_per_user(self, doc) -> Dict:
category_per_user = defaultdict(list)
span_per_user = defaultdict(list)
label_per_user = defaultdict(dict)
for a in doc.categories.all():
category_per_user[a.user.username].append(a.label.text)
for a in doc.spans.all():
label = (a.start_offset, a.end_offset, a.label.text)
span_per_user[a.user.username].append(label)
for user, label in category_per_user.items():
label_per_user[user]['cats'] = label
for user, label in span_per_user.items():
label_per_user[user]['entities'] = label
return label_per_user

13
backend/api/views/download/writer.py

@ -148,3 +148,16 @@ class FastTextWriter(LineWriter):
line.append(record.data)
line = ' '.join(line)
return line
class IntentAndSlotWriter(LineWriter):
extension = 'jsonl'
def create_line(self, record):
return json.dumps({
'id': record.id,
'text': record.data,
'cats': record.label.get('cats', []),
'entities': record.label.get('entities', []),
**record.metadata
}, ensure_ascii=False)
Loading…
Cancel
Save