Browse Source

Add a feature to upload intent detection and slot filling data

pull/1619/head
Hironsan 2 years ago
parent
commit
d8ba13ee9f
6 changed files with 89 additions and 23 deletions
  1. 5
      backend/api/tests/data/intent/example.jsonl
  2. 31
      backend/api/tests/test_tasks.py
  3. 27
      backend/api/views/statistics.py
  4. 8
      backend/api/views/upload/catalog.py
  5. 5
      backend/api/views/upload/examples.py
  6. 36
      backend/api/views/upload/factories.py

5
backend/api/tests/data/intent/example.jsonl

@ -0,0 +1,5 @@
{"text": "exampleA", "entities": [[0, 1, "LOC"]], "cats": ["positive"]}
{"text": "exampleB", "cats": ["positive"]}
{"text": "exampleC", "entities": [[0, 1, "LOC"]]}
{"text": "exampleD"}
{"entities": [[0, 1, "LOC"]], "cats": ["positive"]}

31
backend/api/tests/test_tasks.py

@ -2,8 +2,10 @@ import pathlib
from django.test import TestCase from django.test import TestCase
from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING,
Category, DocType, Example, Span, SpanType)
from ..models import (DOCUMENT_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, Category, DocType, Example, Span,
SpanType)
from ..tasks import ingest_data from ..tasks import ingest_data
from .api.utils import prepare_project from .api.utils import prepare_project
@ -236,3 +238,28 @@ class TestIngestSeq2seqData(TestIngestData):
] ]
self.ingest_data(filename, file_format) self.ingest_data(filename, file_format)
self.assert_examples(dataset) self.assert_examples(dataset)
class TextIngestIntentDetectionAndSlotFillingData(TestIngestData):
task = INTENT_DETECTION_AND_SLOT_FILLING
def assert_examples(self, dataset):
self.assertEqual(Example.objects.count(), len(dataset))
for text, expected_labels in dataset:
example = Example.objects.get(text=text)
cats = set(cat.label.text for cat in example.categories.all())
entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()]
self.assertEqual(cats, set(expected_labels['cats']))
self.assertEqual(entities, expected_labels['entities'])
def test_entities_and_cats(self):
filename = 'intent/example.jsonl'
file_format = 'JSONL'
dataset = [
('exampleA', {'cats': ['positive'], 'entities': [(0, 1, 'LOC')]}),
('exampleB', {'cats': ['positive'], 'entities': []}),
('exampleC', {'cats': [], 'entities': [(0, 1, 'LOC')]}),
('exampleD', {'cats': [], 'entities': []}),
]
self.ingest_data(filename, file_format)
self.assert_examples(dataset)

27
backend/api/views/statistics.py

@ -7,8 +7,8 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from ..models import (Annotation, Category, ExampleState, Project, RoleMapping,
Span)
from ..models import (Annotation, Category, Example, ExampleState, Project,
RoleMapping, Span)
from ..permissions import IsInProjectReadOnlyOrAdmin from ..permissions import IsInProjectReadOnlyOrAdmin
@ -29,7 +29,7 @@ class StatisticsAPI(APIView):
response['user_label'] = user_count response['user_label'] = user_count
if not include or 'total' in include or 'remaining' in include or 'user' in include: if not include or 'total' in include or 'remaining' in include or 'user' in include:
progress = self.progress(project=p)
progress = self.progress()
response.update(progress) response.update(progress)
if not include or 'confirmed_count' in include: if not include or 'confirmed_count' in include:
@ -41,8 +41,8 @@ class StatisticsAPI(APIView):
return Response(response) return Response(response)
def progress(self, project):
examples = project.examples.values('id')
def progress(self):
examples = Example.objects.filter(project=self.kwargs['project_id']).values('id')
total = examples.count() total = examples.count()
done = ExampleState.objects.count_done(examples) done = ExampleState.objects.count_done(examples)
done_by_user = ExampleState.objects.count_user(examples) done_by_user = ExampleState.objects.count_user(examples)
@ -50,8 +50,6 @@ class StatisticsAPI(APIView):
return {'total': total, 'remaining': remaining, 'user': done_by_user} return {'total': total, 'remaining': remaining, 'user': done_by_user}
def label_per_data(self, project): def label_per_data(self, project):
# annotation_class = project.get_annotation_class()
# return annotation_class.objects.get_label_per_data(project=project)
return {}, {} return {}, {}
def confirmed_count(self, project): def confirmed_count(self, project):
@ -70,16 +68,21 @@ class StatisticsAPI(APIView):
return confirmed_count return confirmed_count
class ProgressAPI(APIView):
def get(self, request, *args, **kwargs):
examples = Example.objects.filter(project=self.kwargs['project_id']).values('id')
total = examples.count()
done = ExampleState.objects.count_done(examples)
return {'total': total, 'remaining': total - done}
class LabelFrequency(abc.ABC, APIView): class LabelFrequency(abc.ABC, APIView):
permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin] permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin]
model = Annotation model = Annotation
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.calc_label_frequency()
def calc_label_frequency(self):
project = get_object_or_404(Project, pk=self.kwargs['project_id'])
examples = project.examples.values('id')
examples = Example.objects.filter(project=self.kwargs['project_id']).values('id')
return self.model.objects.calc_label_frequency(examples) return self.model.objects.calc_label_frequency(examples)

8
backend/api/views/upload/catalog.py

@ -4,7 +4,8 @@ from typing import Dict, List, Type
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import Literal from typing_extensions import Literal
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT) SEQUENCE_LABELING, SPEECH2TEXT)
from . import examples from . import examples
@ -244,6 +245,11 @@ Options.register(SEQ2SEQ, JSON, OptionColumn, examples.Text_JSON)
Options.register(SEQ2SEQ, JSONL, OptionColumn, examples.Text_JSONL) Options.register(SEQ2SEQ, JSONL, OptionColumn, examples.Text_JSONL)
Options.register(SEQ2SEQ, Excel, OptionColumn, examples.Text_CSV) Options.register(SEQ2SEQ, Excel, OptionColumn, examples.Text_CSV)
# Intent detection and slof filling
Options.register(INTENT_DETECTION_AND_SLOT_FILLING, TextFile, OptionEncoding, examples.Generic_TextFile)
Options.register(INTENT_DETECTION_AND_SLOT_FILLING, TextLine, OptionEncoding, examples.Generic_TextLine)
Options.register(INTENT_DETECTION_AND_SLOT_FILLING, JSONL, OptionNone, examples.IDSF_JSONL)
# Image classification # Image classification
Options.register(IMAGE_CLASSIFICATION, ImageFile, OptionNone, examples.Generic_ImageFile) Options.register(IMAGE_CLASSIFICATION, ImageFile, OptionNone, examples.Generic_ImageFile)

5
backend/api/views/upload/examples.py

@ -92,3 +92,8 @@ lamb O
Peter B-PER Peter B-PER
Blackburn I-PER Blackburn I-PER
""" """
IDSF_JSONL = """
{"text": "Find a flight from Memphis to Tacoma", "entities": [[0, 26, "City"], [30, 36, "City"]], "cats": ["flight"]}
{"text": "I want to know what airports are in Los Angeles", "entities": [[36, 47, "City"]], "cats": ["airport"]}
"""

36
backend/api/views/upload/factories.py

@ -1,10 +1,16 @@
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION,
INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ,
SEQUENCE_LABELING, SPEECH2TEXT) SEQUENCE_LABELING, SPEECH2TEXT)
from . import builders, catalog, cleaners, data, label, parsers, readers from . import builders, catalog, cleaners, data, label, parsers, readers
def get_data_class(project_type: str): def get_data_class(project_type: str):
text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ]
text_projects = [
DOCUMENT_CLASSIFICATION,
SEQUENCE_LABELING,
SEQ2SEQ,
INTENT_DETECTION_AND_SLOT_FILLING
]
if project_type in text_projects: if project_type in text_projects:
return data.TextData return data.TextData
else: else:
@ -59,14 +65,28 @@ def create_bulder(project, **kwargs):
name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN, name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN,
value_class=get_data_class(project.project_type) value_class=get_data_class(project.project_type)
) )
# Todo: If project is EntityClassification,
# If project is intent detection and slot filling,
# column names are fixed: entities, cats # column names are fixed: entities, cats
label_column = builders.LabelColumn(
name=kwargs.get('column_label') or readers.DEFAULT_LABEL_COLUMN,
value_class=get_label_class(project.project_type)
)
if project.project_type == INTENT_DETECTION_AND_SLOT_FILLING:
label_columns = [
builders.LabelColumn(
name='cats',
value_class=label.CategoryLabel
),
builders.LabelColumn(
name='entities',
value_class=label.SpanLabel
)
]
else:
label_columns = [
builders.LabelColumn(
name=kwargs.get('column_label') or readers.DEFAULT_LABEL_COLUMN,
value_class=get_label_class(project.project_type)
)
]
builder = builders.ColumnBuilder( builder = builders.ColumnBuilder(
data_column=data_column, data_column=data_column,
label_columns=[label_column]
label_columns=label_columns
) )
return builder return builder
Loading…
Cancel
Save