You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
1.9 KiB

  1. import abc
  2. from typing import List
  3. from auto_labeling_pipeline.labels import Labels
  4. from django.contrib.auth.models import User
  5. from api.models import Project, Example
  6. from label_types.models import CategoryType, SpanType
  7. from labels.models import Label, Category, Span, TextLabel
  8. class LabelCollection(abc.ABC):
  9. label_type = None
  10. model = None
  11. def __init__(self, labels):
  12. self.labels = labels
  13. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  14. mapping = {
  15. c.text: c for c in self.label_type.objects.filter(project=project)
  16. }
  17. annotations = []
  18. for label in self.labels:
  19. if label['label'] not in mapping:
  20. continue
  21. label['example'] = example
  22. label['label'] = mapping[label['label']]
  23. label['user'] = user
  24. annotations.append(self.model(**label))
  25. return annotations
  26. def save(self, project: Project, example: Example, user: User):
  27. labels = self.transform(project, example, user)
  28. labels = self.model.objects.filter_annotatable_labels(labels, project)
  29. self.model.objects.bulk_create(labels)
  30. class Categories(LabelCollection):
  31. label_type = CategoryType
  32. model = Category
  33. class Spans(LabelCollection):
  34. label_type = SpanType
  35. model = Span
  36. class Texts(LabelCollection):
  37. model = TextLabel
  38. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  39. annotations = []
  40. for label in self.labels:
  41. label['example'] = example
  42. label['user'] = user
  43. annotations.append(self.model(**label))
  44. return annotations
  45. def create_labels(task_type: str, labels: Labels) -> LabelCollection:
  46. return {
  47. 'Category': Categories,
  48. 'Span': Spans,
  49. 'Text': Texts
  50. }[task_type](labels.dict())