You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
1.9 KiB

  1. import abc
  2. from typing import List
  3. from auto_labeling_pipeline.labels import Labels
  4. from django.contrib.auth.models import User
  5. from api.models import Project
  6. from examples.models import Example
  7. from label_types.models import CategoryType, SpanType
  8. from labels.models import Label, Category, Span, TextLabel
  9. class LabelCollection(abc.ABC):
  10. label_type = None
  11. model = None
  12. def __init__(self, labels):
  13. self.labels = labels
  14. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  15. mapping = {
  16. c.text: c for c in self.label_type.objects.filter(project=project)
  17. }
  18. annotations = []
  19. for label in self.labels:
  20. if label['label'] not in mapping:
  21. continue
  22. label['example'] = example
  23. label['label'] = mapping[label['label']]
  24. label['user'] = user
  25. annotations.append(self.model(**label))
  26. return annotations
  27. def save(self, project: Project, example: Example, user: User):
  28. labels = self.transform(project, example, user)
  29. labels = self.model.objects.filter_annotatable_labels(labels, project)
  30. self.model.objects.bulk_create(labels)
  31. class Categories(LabelCollection):
  32. label_type = CategoryType
  33. model = Category
  34. class Spans(LabelCollection):
  35. label_type = SpanType
  36. model = Span
  37. class Texts(LabelCollection):
  38. model = TextLabel
  39. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  40. annotations = []
  41. for label in self.labels:
  42. label['example'] = example
  43. label['user'] = user
  44. annotations.append(self.model(**label))
  45. return annotations
  46. def create_labels(task_type: str, labels: Labels) -> LabelCollection:
  47. return {
  48. 'Category': Categories,
  49. 'Span': Spans,
  50. 'Text': Texts
  51. }[task_type](labels.dict())