You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

89 lines
2.8 KiB

  1. import abc
  2. from itertools import groupby
  3. from typing import Dict, List, Tuple
  4. from .examples import Examples
  5. from .label import Label
  6. from .label_types import LabelTypes
  7. from labels.models import Category as CategoryModel
  8. from labels.models import Label as LabelModel
  9. from labels.models import Relation as RelationModel
  10. from labels.models import Span as SpanModel
  11. from labels.models import TextLabel as TextLabelModel
  12. from projects.models import Project
  13. class Labels(abc.ABC):
  14. label_model = LabelModel
  15. def __init__(self, labels: List[Label], types: LabelTypes):
  16. self.labels = labels
  17. self.types = types
  18. def __len__(self) -> int:
  19. return len(self.labels)
  20. def clean(self, project: Project):
  21. pass
  22. def save_types(self, project: Project):
  23. types = [label.create_type(project) for label in self.labels]
  24. filtered_types = list(filter(None, types))
  25. self.types.save(filtered_types)
  26. self.types.update(project)
  27. def save(self, user, examples: Examples, **kwargs):
  28. labels = [
  29. label.create(user, examples[label.example_uuid], self.types, **kwargs)
  30. for label in self.labels
  31. if label.example_uuid in examples
  32. ]
  33. self.label_model.objects.bulk_create(labels)
  34. class Categories(Labels):
  35. label_model = CategoryModel
  36. def clean(self, project: Project):
  37. exclusive = getattr(project, "single_class_classification", False)
  38. if exclusive:
  39. groups = groupby(self.labels, lambda label: label.example_uuid)
  40. self.labels = [next(group) for _, group in groups]
  41. class Spans(Labels):
  42. label_model = SpanModel
  43. def clean(self, project: Project):
  44. allow_overlapping = getattr(project, "allow_overlapping", False)
  45. if allow_overlapping:
  46. return
  47. spans = []
  48. groups = groupby(self.labels, lambda label: label.example_uuid)
  49. for _, group in groups:
  50. labels = sorted(group)
  51. last_offset = -1
  52. for label in labels:
  53. if getattr(label, "start_offset") >= last_offset:
  54. last_offset = getattr(label, "end_offset")
  55. spans.append(label)
  56. self.labels = spans
  57. @property
  58. def id_to_span(self) -> Dict[Tuple[int, str], SpanModel]:
  59. uuids = [str(span.uuid) for span in self.labels]
  60. spans = SpanModel.objects.filter(uuid__in=uuids)
  61. uuid_to_span = {span.uuid: span for span in spans}
  62. return {(span.id, str(span.example_uuid)): uuid_to_span[span.uuid] for span in self.labels}
  63. class Texts(Labels):
  64. label_model = TextLabelModel
  65. class Relations(Labels):
  66. label_model = RelationModel
  67. def save(self, user, examples: Examples, **kwargs):
  68. id_to_span = kwargs["spans"].id_to_span
  69. super().save(user, examples, id_to_span=id_to_span)