From 01b54d78711329f7c91462d53774e7e9cfd096b3 Mon Sep 17 00:00:00 2001
From: Hironsan
Date: Fri, 14 May 2021 08:07:47 +0900
Subject: [PATCH] Add file repository

---
Review notes (text between the `---` marker and the first `diff --git` line
is commentary and is not part of the commit):

* The line structure of this patch was restored from a whitespace-collapsed
  copy. Indentation of context lines was reconstructed; if the patch does not
  apply cleanly, retry with `git apply --ignore-whitespace`.
* FileRepository.list: the fallback record for an example without labels now
  exports `example.filename`, matching the labeled branch of the same method
  (it previously exported `example.text`).
* factory.create_repository: the context line
  `ValueError(f'Invalid project type: ...')` constructs the exception but
  never raises it. It is untouched here because it is a context line of this
  commit; TODO: follow up with `raise ValueError(...)`.

 backend/api/views/download/factory.py      |  4 +-
 backend/api/views/download/repositories.py | 67 ++++++++++++++++++----
 2 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/backend/api/views/download/factory.py b/backend/api/views/download/factory.py
index acf205d1..c1c65666 100644
--- a/backend/api/views/download/factory.py
+++ b/backend/api/views/download/factory.py
@@ -1,6 +1,7 @@
 from typing import Type
 
-from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING
+from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
+                       SEQUENCE_LABELING)
 from . import catalog, repositories, writer
 
 
@@ -9,6 +10,7 @@ def create_repository(project) -> repositories.BaseRepository:
         DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
         SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
         SEQ2SEQ: repositories.Seq2seqRepository,
+        IMAGE_CLASSIFICATION: repositories.FileRepository,
     }
     if project.project_type not in mapping:
         ValueError(f'Invalid project type: {project.project_type}')
diff --git a/backend/api/views/download/repositories.py b/backend/api/views/download/repositories.py
index c5152b2b..aa8f09c1 100644
--- a/backend/api/views/download/repositories.py
+++ b/backend/api/views/download/repositories.py
@@ -3,7 +3,7 @@ import itertools
 from collections import defaultdict
 from typing import Dict, Iterator, List
 
-from ...models import Project
+from ...models import Document, Project
 from .data import Record
 
 
@@ -17,11 +17,56 @@ class BaseRepository(abc.ABC):
         pass
 
 
+class FileRepository(BaseRepository):
+
+    def list(self, export_approved=False) -> Iterator[Record]:
+        examples = self.project.examples.all()
+        if export_approved:
+            examples = examples.exclude(annotations_approved_by=None)
+
+        for example in examples:
+            label_per_user = self.label_per_user(example)
+            if self.project.collaborative_annotation:
+                label_per_user = self.reduce_user(label_per_user)
+            for user, label in label_per_user.items():
+                yield Record(
+                    id=example.id,
+                    data=example.filename,
+                    label=label,
+                    user=user,
+                    metadata=example.meta
+                )
+            # todo:
+            # If there is no label, export the doc with `unknown` user.
+            # This is a quick solution.
+            # In the future, the doc without label will be exported
+            # with the user who approved the doc.
+            # This means I will allow each user to be able to approve the doc.
+            if len(label_per_user) == 0:
+                yield Record(
+                    id=example.id,
+                    data=example.filename,
+                    label=[],
+                    user='unknown',
+                    metadata={}
+                )
+
+    def label_per_user(self, example) -> Dict:
+        label_per_user = defaultdict(list)
+        for a in example.categories.all():
+            label_per_user[a.user.username].append(a.label.text)
+        return label_per_user
+
+    def reduce_user(self, label_per_user: Dict[str, List]):
+        value = list(itertools.chain(*label_per_user.values()))
+        return {'all': value}
+
+
 class TextRepository(BaseRepository):
 
     @property
     def docs(self):
-        return self.project.documents.all()
+        return Document.objects.filter(project=self.project)
 
     def list(self, export_approved=False):
         docs = self.docs
@@ -68,13 +113,13 @@ class TextClassificationRepository(TextRepository):
 
     @property
     def docs(self):
-        return self.project.documents.prefetch_related(
-            'doc_annotations__user', 'doc_annotations__label'
+        return Document.objects.filter(project=self.project).prefetch_related(
+            'categories__user', 'categories__label'
         )
 
     def label_per_user(self, doc) -> Dict:
         label_per_user = defaultdict(list)
-        for a in doc.doc_annotations.all():
+        for a in doc.categories.all():
             label_per_user[a.user.username].append(a.label.text)
         return label_per_user
 
@@ -83,13 +128,13 @@ class SequenceLabelingRepository(TextRepository):
 
     @property
     def docs(self):
-        return self.project.documents.prefetch_related(
-            'seq_annotations__user', 'seq_annotations__label'
+        return Document.objects.filter(project=self.project).prefetch_related(
+            'spans__user', 'spans__label'
         )
 
     def label_per_user(self, doc) -> Dict:
         label_per_user = defaultdict(list)
-        for a in doc.seq_annotations.all():
+        for a in doc.spans.all():
             label = (a.start_offset, a.end_offset, a.label.text)
             label_per_user[a.user.username].append(label)
         return label_per_user
@@ -99,12 +144,12 @@ class Seq2seqRepository(TextRepository):
 
     @property
     def docs(self):
-        return self.project.documents.prefetch_related(
-            'seq2seq_annotations__user'
+        return Document.objects.filter(project=self.project).prefetch_related(
+            'texts__user'
         )
 
     def label_per_user(self, doc) -> Dict:
         label_per_user = defaultdict(list)
-        for a in doc.seq2seq_annotations.all():
+        for a in doc.texts.all():
             label_per_user[a.user.username].append(a.text)
         return label_per_user