From d8ba13ee9f1e41621382064d9e46c47bc38025cd Mon Sep 17 00:00:00 2001 From: Hironsan Date: Thu, 23 Dec 2021 11:31:34 +0900 Subject: [PATCH] Add a feature to upload intent detection and slot filling data --- backend/api/tests/data/intent/example.jsonl | 5 +++ backend/api/tests/test_tasks.py | 31 ++++++++++++++++-- backend/api/views/statistics.py | 27 +++++++++------- backend/api/views/upload/catalog.py | 8 ++++- backend/api/views/upload/examples.py | 5 +++ backend/api/views/upload/factories.py | 36 ++++++++++++++++----- 6 files changed, 89 insertions(+), 23 deletions(-) create mode 100644 backend/api/tests/data/intent/example.jsonl diff --git a/backend/api/tests/data/intent/example.jsonl b/backend/api/tests/data/intent/example.jsonl new file mode 100644 index 00000000..6beeb616 --- /dev/null +++ b/backend/api/tests/data/intent/example.jsonl @@ -0,0 +1,5 @@ +{"text": "exampleA", "entities": [[0, 1, "LOC"]], "cats": ["positive"]} +{"text": "exampleB", "cats": ["positive"]} +{"text": "exampleC", "entities": [[0, 1, "LOC"]]} +{"text": "exampleD"} +{"entities": [[0, 1, "LOC"]], "cats": ["positive"]} diff --git a/backend/api/tests/test_tasks.py b/backend/api/tests/test_tasks.py index 317971ae..7bc326ab 100644 --- a/backend/api/tests/test_tasks.py +++ b/backend/api/tests/test_tasks.py @@ -2,8 +2,10 @@ import pathlib from django.test import TestCase -from ..models import (DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING, - Category, DocType, Example, Span, SpanType) +from ..models import (DOCUMENT_CLASSIFICATION, + INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ, + SEQUENCE_LABELING, Category, DocType, Example, Span, + SpanType) from ..tasks import ingest_data from .api.utils import prepare_project @@ -236,3 +238,28 @@ class TestIngestSeq2seqData(TestIngestData): ] self.ingest_data(filename, file_format) self.assert_examples(dataset) + + +class TextIngestIntentDetectionAndSlotFillingData(TestIngestData): + task = INTENT_DETECTION_AND_SLOT_FILLING + + def assert_examples(self, dataset): + self.assertEqual(Example.objects.count(), len(dataset)) + for text, expected_labels in dataset: + example = Example.objects.get(text=text) + cats = set(cat.label.text for cat in example.categories.all()) + entities = [(span.start_offset, span.end_offset, span.label.text) for span in example.spans.all()] + self.assertEqual(cats, set(expected_labels['cats'])) + self.assertEqual(entities, expected_labels['entities']) + + def test_entities_and_cats(self): + filename = 'intent/example.jsonl' + file_format = 'JSONL' + dataset = [ + ('exampleA', {'cats': ['positive'], 'entities': [(0, 1, 'LOC')]}), + ('exampleB', {'cats': ['positive'], 'entities': []}), + ('exampleC', {'cats': [], 'entities': [(0, 1, 'LOC')]}), + ('exampleD', {'cats': [], 'entities': []}), + ] + self.ingest_data(filename, file_format) + self.assert_examples(dataset) diff --git a/backend/api/views/statistics.py b/backend/api/views/statistics.py index bae6a013..bb913df0 100644 --- a/backend/api/views/statistics.py +++ b/backend/api/views/statistics.py @@ -7,8 +7,8 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView -from ..models import (Annotation, Category, ExampleState, Project, RoleMapping, - Span) +from ..models import (Annotation, Category, Example, ExampleState, Project, + RoleMapping, Span) from ..permissions import IsInProjectReadOnlyOrAdmin @@ -29,7 +29,7 @@ class StatisticsAPI(APIView): response['user_label'] = user_count if not include or 'total' in include or 'remaining' in include or 'user' in include: - progress = self.progress(project=p) + progress = self.progress() response.update(progress) if not include or 'confirmed_count' in include: @@ -41,8 +41,8 @@ class StatisticsAPI(APIView): return Response(response) - def progress(self, project): - examples = project.examples.values('id') + def progress(self): + examples = Example.objects.filter(project=self.kwargs['project_id']).values('id') total = examples.count() done = ExampleState.objects.count_done(examples) done_by_user = ExampleState.objects.count_user(examples) @@ -50,8 +50,6 @@ class StatisticsAPI(APIView): return {'total': total, 'remaining': remaining, 'user': done_by_user} def label_per_data(self, project): - # annotation_class = project.get_annotation_class() - # return annotation_class.objects.get_label_per_data(project=project) return {}, {} def confirmed_count(self, project): @@ -70,16 +68,21 @@ class StatisticsAPI(APIView): return confirmed_count +class ProgressAPI(APIView): + + def get(self, request, *args, **kwargs): + examples = Example.objects.filter(project=self.kwargs['project_id']).values('id') + total = examples.count() + done = ExampleState.objects.count_done(examples) + return {'total': total, 'remaining': total - done} + + class LabelFrequency(abc.ABC, APIView): permission_classes = [IsAuthenticated & IsInProjectReadOnlyOrAdmin] model = Annotation def get(self, request, *args, **kwargs): - return self.calc_label_frequency() - - def calc_label_frequency(self): - project = get_object_or_404(Project, pk=self.kwargs['project_id']) - examples = project.examples.values('id') + examples = Example.objects.filter(project=self.kwargs['project_id']).values('id') return self.model.objects.calc_label_frequency(examples) diff --git a/backend/api/views/upload/catalog.py b/backend/api/views/upload/catalog.py index 57b89d8c..99f23b22 100644 --- a/backend/api/views/upload/catalog.py +++ b/backend/api/views/upload/catalog.py @@ -4,7 +4,8 @@ from typing import Dict, List, Type from pydantic import BaseModel from typing_extensions import Literal -from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, +from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, + INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT) from . import examples @@ -244,6 +245,11 @@ Options.register(SEQ2SEQ, JSON, OptionColumn, examples.Text_JSON) Options.register(SEQ2SEQ, JSONL, OptionColumn, examples.Text_JSONL) Options.register(SEQ2SEQ, Excel, OptionColumn, examples.Text_CSV) +# Intent detection and slof filling +Options.register(INTENT_DETECTION_AND_SLOT_FILLING, TextFile, OptionEncoding, examples.Generic_TextFile) +Options.register(INTENT_DETECTION_AND_SLOT_FILLING, TextLine, OptionEncoding, examples.Generic_TextLine) +Options.register(INTENT_DETECTION_AND_SLOT_FILLING, JSONL, OptionNone, examples.IDSF_JSONL) + # Image classification Options.register(IMAGE_CLASSIFICATION, ImageFile, OptionNone, examples.Generic_ImageFile) diff --git a/backend/api/views/upload/examples.py b/backend/api/views/upload/examples.py index 1dbc337c..f9aca242 100644 --- a/backend/api/views/upload/examples.py +++ b/backend/api/views/upload/examples.py @@ -92,3 +92,8 @@ lamb O Peter B-PER Blackburn I-PER """ + +IDSF_JSONL = """ +{"text": "Find a flight from Memphis to Tacoma", "entities": [[0, 26, "City"], [30, 36, "City"]], "cats": ["flight"]} +{"text": "I want to know what airports are in Los Angeles", "entities": [[36, 47, "City"]], "cats": ["airport"]} +""" diff --git a/backend/api/views/upload/factories.py b/backend/api/views/upload/factories.py index 1050e67a..a022f624 100644 --- a/backend/api/views/upload/factories.py +++ b/backend/api/views/upload/factories.py @@ -1,10 +1,16 @@ -from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ, +from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, + INTENT_DETECTION_AND_SLOT_FILLING, SEQ2SEQ, SEQUENCE_LABELING, SPEECH2TEXT) from . import builders, catalog, cleaners, data, label, parsers, readers def get_data_class(project_type: str): - text_projects = [DOCUMENT_CLASSIFICATION, SEQUENCE_LABELING, SEQ2SEQ] + text_projects = [ + DOCUMENT_CLASSIFICATION, + SEQUENCE_LABELING, + SEQ2SEQ, + INTENT_DETECTION_AND_SLOT_FILLING + ] if project_type in text_projects: return data.TextData else: @@ -59,14 +65,28 @@ def create_bulder(project, **kwargs): name=kwargs.get('column_data') or readers.DEFAULT_TEXT_COLUMN, value_class=get_data_class(project.project_type) ) - # Todo: If project is EntityClassification, + # If project is intent detection and slot filling, # column names are fixed: entities, cats - label_column = builders.LabelColumn( - name=kwargs.get('column_label') or readers.DEFAULT_LABEL_COLUMN, - value_class=get_label_class(project.project_type) - ) + if project.project_type == INTENT_DETECTION_AND_SLOT_FILLING: + label_columns = [ + builders.LabelColumn( + name='cats', + value_class=label.CategoryLabel + ), + builders.LabelColumn( + name='entities', + value_class=label.SpanLabel + ) + ] + else: + label_columns = [ + builders.LabelColumn( + name=kwargs.get('column_label') or readers.DEFAULT_LABEL_COLUMN, + value_class=get_label_class(project.project_type) + ) + ] builder = builders.ColumnBuilder( data_column=data_column, - label_columns=[label_column] + label_columns=label_columns ) return builder