You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
1.9 KiB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
  1. import abc
  2. from typing import List, Type
  3. from auto_labeling_pipeline.labels import Labels
  4. from django.contrib.auth.models import User
  5. from examples.models import Example
  6. from label_types.models import CategoryType, LabelType, SpanType
  7. from labels.models import Category, Label, Span, TextLabel
  8. from projects.models import Project
  9. class LabelCollection(abc.ABC):
  10. label_type: Type[LabelType]
  11. model: Type[Label]
  12. def __init__(self, labels):
  13. self.labels = labels
  14. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  15. mapping = {c.text: c for c in self.label_type.objects.filter(project=project)}
  16. annotations = []
  17. for label in self.labels:
  18. if label["label"] not in mapping:
  19. continue
  20. label["example"] = example
  21. label["label"] = mapping[label["label"]]
  22. label["user"] = user
  23. annotations.append(self.model(**label))
  24. return annotations
  25. def save(self, project: Project, example: Example, user: User):
  26. labels = self.transform(project, example, user)
  27. labels = self.model.objects.filter_annotatable_labels(labels, project)
  28. self.model.objects.bulk_create(labels)
  29. class Categories(LabelCollection):
  30. label_type = CategoryType
  31. model = Category
  32. class Spans(LabelCollection):
  33. label_type = SpanType
  34. model = Span
  35. class Texts(LabelCollection):
  36. model = TextLabel
  37. def transform(self, project: Project, example: Example, user: User) -> List[Label]:
  38. annotations = []
  39. for label in self.labels:
  40. label["example"] = example
  41. label["user"] = user
  42. annotations.append(self.model(**label))
  43. return annotations
  44. def create_labels(task_type: str, labels: Labels) -> LabelCollection:
  45. return {"Category": Categories, "Span": Spans, "Text": Texts}[task_type](labels.dict())