Browse Source

Add file repository

pull/1370/head
Hironsan 3 years ago
parent
commit
01b54d7871
2 changed files with 59 additions and 12 deletions
  1. 4
      backend/api/views/download/factory.py
  2. 67
      backend/api/views/download/repositories.py

4
backend/api/views/download/factory.py

@ -1,6 +1,7 @@
from typing import Type from typing import Type
from ...models import DOCUMENT_CLASSIFICATION, SEQ2SEQ, SEQUENCE_LABELING
from ...models import (DOCUMENT_CLASSIFICATION, IMAGE_CLASSIFICATION, SEQ2SEQ,
SEQUENCE_LABELING)
from . import catalog, repositories, writer from . import catalog, repositories, writer
@ -9,6 +10,7 @@ def create_repository(project) -> repositories.BaseRepository:
DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository, DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository,
SEQUENCE_LABELING: repositories.SequenceLabelingRepository, SEQUENCE_LABELING: repositories.SequenceLabelingRepository,
SEQ2SEQ: repositories.Seq2seqRepository, SEQ2SEQ: repositories.Seq2seqRepository,
IMAGE_CLASSIFICATION: repositories.FileRepository,
} }
if project.project_type not in mapping: if project.project_type not in mapping:
ValueError(f'Invalid project type: {project.project_type}') ValueError(f'Invalid project type: {project.project_type}')

67
backend/api/views/download/repositories.py

@ -3,7 +3,7 @@ import itertools
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterator, List from typing import Dict, Iterator, List
from ...models import Project
from ...models import Document, Project
from .data import Record from .data import Record
@ -17,11 +17,56 @@ class BaseRepository(abc.ABC):
pass pass
class FileRepository(BaseRepository):
def list(self, export_approved=False) -> Iterator[Record]:
examples = self.project.examples.all()
if export_approved:
examples = examples.exclude(annotations_approved_by=None)
for example in examples:
label_per_user = self.label_per_user(example)
if self.project.collaborative_annotation:
label_per_user = self.reduce_user(label_per_user)
for user, label in label_per_user.items():
yield Record(
id=example.id,
data=example.filename,
label=label,
user=user,
metadata=example.meta
)
# todo:
# If there is no label, export the doc with `unknown` user.
# This is a quick solution.
# In the future, the doc without label will be exported
# with the user who approved the doc.
# This means I will allow each user to be able to approve the doc.
if len(label_per_user) == 0:
yield Record(
id=example.id,
data=example.text,
label=[],
user='unknown',
metadata={}
)
def label_per_user(self, example) -> Dict:
label_per_user = defaultdict(list)
for a in example.categories.all():
label_per_user[a.user.username].append(a.label.text)
return label_per_user
def reduce_user(self, label_per_user: Dict[str, List]):
value = list(itertools.chain(*label_per_user.values()))
return {'all': value}
class TextRepository(BaseRepository): class TextRepository(BaseRepository):
@property @property
def docs(self): def docs(self):
return self.project.documents.all()
return Document.objects.filter(project=self.project)
def list(self, export_approved=False): def list(self, export_approved=False):
docs = self.docs docs = self.docs
@ -68,13 +113,13 @@ class TextClassificationRepository(TextRepository):
@property @property
def docs(self): def docs(self):
return self.project.documents.prefetch_related(
'doc_annotations__user', 'doc_annotations__label'
return Document.objects.filter(project=self.project).prefetch_related(
'categories__user', 'categories__label'
) )
def label_per_user(self, doc) -> Dict: def label_per_user(self, doc) -> Dict:
label_per_user = defaultdict(list) label_per_user = defaultdict(list)
for a in doc.doc_annotations.all():
for a in doc.categories.all():
label_per_user[a.user.username].append(a.label.text) label_per_user[a.user.username].append(a.label.text)
return label_per_user return label_per_user
@ -83,13 +128,13 @@ class SequenceLabelingRepository(TextRepository):
@property @property
def docs(self): def docs(self):
return self.project.documents.prefetch_related(
'seq_annotations__user', 'seq_annotations__label'
return Document.objects.filter(project=self.project).prefetch_related(
'spans__user', 'spans__label'
) )
def label_per_user(self, doc) -> Dict: def label_per_user(self, doc) -> Dict:
label_per_user = defaultdict(list) label_per_user = defaultdict(list)
for a in doc.seq_annotations.all():
for a in doc.spans.all():
label = (a.start_offset, a.end_offset, a.label.text) label = (a.start_offset, a.end_offset, a.label.text)
label_per_user[a.user.username].append(label) label_per_user[a.user.username].append(label)
return label_per_user return label_per_user
@ -99,12 +144,12 @@ class Seq2seqRepository(TextRepository):
@property @property
def docs(self): def docs(self):
return self.project.documents.prefetch_related(
'seq2seq_annotations__user'
return Document.objects.filter(project=self.project).prefetch_related(
'texts__user'
) )
def label_per_user(self, doc) -> Dict: def label_per_user(self, doc) -> Dict:
label_per_user = defaultdict(list) label_per_user = defaultdict(list)
for a in doc.seq2seq_annotations.all():
for a in doc.texts.all():
label_per_user[a.user.username].append(a.text) label_per_user[a.user.username].append(a.text)
return label_per_user return label_per_user
Loading…
Cancel
Save