You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

498 lines
17 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import csv
  2. import io
  3. import json
  4. from collections import Counter
  5. from itertools import chain
  6. from django.db import transaction
  7. from django.http import HttpResponse
  8. from django.shortcuts import get_object_or_404
  9. from django_filters.rest_framework import DjangoFilterBackend
  10. from rest_framework import generics, filters, status
  11. from rest_framework.exceptions import ParseError, ValidationError
  12. from rest_framework.permissions import IsAuthenticated, IsAdminUser
  13. from rest_framework.response import Response
  14. from rest_framework.views import APIView
  15. from rest_framework.parsers import MultiPartParser
  16. from .exceptions import FileParseException
  17. from .models import Project, Label, Document
  18. from .models import SequenceAnnotation
  19. from .permissions import IsAdminUserAndWriteOnly, IsProjectUser, IsMyEntity, IsOwnAnnotation
  20. from .serializers import ProjectSerializer, LabelSerializer, DocumentSerializer
  21. from .serializers import SequenceAnnotationSerializer, DocumentAnnotationSerializer, Seq2seqAnnotationSerializer
  22. from .serializers import ProjectPolymorphicSerializer
  23. from .utils import extract_label
  class ProjectList(generics.ListCreateAPIView):
      """List the requesting user's projects, or create a new project.

      Access is controlled by ``IsAuthenticated`` and
      ``IsAdminUserAndWriteOnly``; results are not paginated.
      """

      queryset = Project.objects.all()
      serializer_class = ProjectPolymorphicSerializer
      pagination_class = None
      permission_classes = (IsAuthenticated, IsAdminUserAndWriteOnly)

      def get_queryset(self):
          """Return the ``projects`` relation of the requesting user."""
          return self.request.user.projects

      def create(self, request, *args, **kwargs):
          """Create a project whose ``users`` is the requesting user only.

          Any ``users`` value supplied in the request body is overwritten
          before the parent implementation runs.
          """
          request.data['users'] = [self.request.user.id]
          # Unpack so the parent receives the same positional and keyword
          # arguments this method was called with, instead of a tuple and a
          # dict passed as two extra positional arguments.
          return super().create(request, *args, **kwargs)
  class ProjectDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete the project named by ``project_id``."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
      # The URL keyword argument that carries the project's lookup value.
      lookup_url_kwarg = 'project_id'
      serializer_class = ProjectSerializer
      queryset = Project.objects.all()
  class StatisticsAPI(APIView):
      """Report per-label and per-user annotation counts for one project."""

      pagination_class = None
      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)

      def get(self, request, *args, **kwargs):
          """Return annotation counts for the project named by ``project_id``.

          The response body has the shape::

              {'label': {'labels': [...], 'data': [...]},
               'user': {'users': [...], 'data': [...]}}

          Each ``data`` list is aligned index-by-index with its sibling list
          of names. Labels and users without any annotation get a count of 0.
          Responds with 404 when the project does not exist.
          """
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          labels = [label.text for label in project.labels.all()]
          users = [user.username for user in project.users.all()]

          # Walk each document's annotations once and feed both counters,
          # rather than calling get_annotations() twice per document.
          label_count = Counter()
          user_count = Counter()
          for doc in project.documents.all():
              for annotation in doc.get_annotations():
                  label_count[annotation.label.text] += 1
                  user_count[annotation.user.username] += 1

          # Counter returns 0 for names that never appeared.
          label_data = [label_count[name] for name in labels]
          user_data = [user_count[name] for name in users]

          response = {'label': {'labels': labels, 'data': label_data},
                      'user': {'users': users, 'data': user_data}}
          return Response(response)
  class LabelList(generics.ListCreateAPIView):
      """List or create the labels of the project named by ``project_id``."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
      pagination_class = None
      serializer_class = LabelSerializer
      queryset = Label.objects.all()

      def get_queryset(self):
          """Restrict the label queryset to the project from the URL."""
          project_id = self.kwargs['project_id']
          return self.queryset.filter(project=project_id)

      def perform_create(self, serializer):
          """Save the new label under the URL's project (404 if it is absent)."""
          serializer.save(
              project=get_object_or_404(Project, pk=self.kwargs['project_id']),
          )
  class LabelDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete the label named by ``label_id``."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
      # The URL keyword argument that carries the label's lookup value.
      lookup_url_kwarg = 'label_id'
      serializer_class = LabelSerializer
      queryset = Label.objects.all()
  class DocumentList(generics.ListCreateAPIView):
      """List or create the documents of the project named by ``project_id``.

      Listing supports filtering by annotation label id, free-text search over
      ``text``, and ordering by the timestamp fields listed below.
      """

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
      serializer_class = DocumentSerializer
      queryset = Document.objects.all()

      filter_backends = (
          DjangoFilterBackend,
          filters.SearchFilter,
          filters.OrderingFilter,
      )
      filter_fields = (
          'doc_annotations__label__id',
          'seq_annotations__label__id',
      )
      search_fields = ('text', )
      ordering_fields = (
          'created_at',
          'updated_at',
          'doc_annotations__updated_at',
          'seq_annotations__updated_at',
      )

      def get_queryset(self):
          """Restrict the document queryset to the project from the URL."""
          project_id = self.kwargs['project_id']
          return self.queryset.filter(project=project_id)

      def perform_create(self, serializer):
          """Save the new document under the URL's project (404 if absent)."""
          serializer.save(
              project=get_object_or_404(Project, pk=self.kwargs['project_id']),
          )
  class DocumentDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete the document named by ``doc_id``."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUserAndWriteOnly)
      # The URL keyword argument that carries the document's lookup value.
      lookup_url_kwarg = 'doc_id'
      serializer_class = DocumentSerializer
      queryset = Document.objects.all()
  class EntityList(generics.ListCreateAPIView):
      """List or create the requesting user's sequence annotations on a document."""

      permission_classes = (IsAuthenticated, IsProjectUser)
      pagination_class = None
      serializer_class = SequenceAnnotationSerializer
      queryset = SequenceAnnotation.objects.all()

      def get_queryset(self):
          """Keep only annotations on ``doc_id`` made by the requesting user."""
          criteria = {
              'document': self.kwargs['doc_id'],
              'user': self.request.user,
          }
          return self.queryset.filter(**criteria)

      def perform_create(self, serializer):
          """Save the annotation against the URL's document and the requester."""
          target = get_object_or_404(Document, pk=self.kwargs['doc_id'])
          serializer.save(document=target, user=self.request.user)
  class EntityDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete the sequence annotation named by ``entity_id``."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsMyEntity)
      # The URL keyword argument that carries the annotation's lookup value.
      lookup_url_kwarg = 'entity_id'
      serializer_class = SequenceAnnotationSerializer
      queryset = SequenceAnnotation.objects.all()
  class AnnotationList(generics.ListCreateAPIView):
      """List or create the requesting user's annotations on a document.

      The serializer and the model are both obtained from the project named
      by ``project_id``.
      """

      permission_classes = (IsAuthenticated, IsProjectUser)
      pagination_class = None

      def get_serializer_class(self):
          """Return (and cache on the view) the project's annotation serializer."""
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          serializer_cls = owner.get_annotation_serializer()
          self.serializer_class = serializer_cls
          return serializer_cls

      def get_queryset(self):
          """Return the requester's annotations on ``doc_id``; also stored on the view."""
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          annotation_model = owner.get_annotation_class()
          matches = annotation_model.objects.filter(
              document=self.kwargs['doc_id'],
              user=self.request.user,
          )
          self.queryset = matches
          return matches

      def perform_create(self, serializer):
          """Save the annotation against the URL's document and the requester."""
          target = get_object_or_404(Document, pk=self.kwargs['doc_id'])
          serializer.save(document=target, user=self.request.user)
  class AnnotationDetail(generics.RetrieveUpdateDestroyAPIView):
      """Retrieve, update or delete the annotation named by ``annotation_id``.

      The serializer and the model are both obtained from the project named
      by ``project_id``.
      """

      permission_classes = (IsAuthenticated, IsProjectUser, IsOwnAnnotation)
      lookup_url_kwarg = 'annotation_id'

      def get_serializer_class(self):
          """Return (and cache on the view) the project's annotation serializer."""
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          serializer_cls = owner.get_annotation_serializer()
          self.serializer_class = serializer_cls
          return serializer_cls

      def get_queryset(self):
          """Return every object of the project's annotation class; also stored on the view."""
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          everything = owner.get_annotation_class().objects.all()
          self.queryset = everything
          return everything
  class TextUploadAPI(APIView):
      """Accept a multipart file upload and import it into a project."""

      parser_classes = (MultiPartParser,)
      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)

      def post(self, request, *args, **kwargs):
          """Import the uploaded ``file`` using the handler for ``format``.

          Expects a multipart body with a ``file`` part and a ``format``
          field. The handler is obtained from the project named by
          ``project_id`` and is given the file and the requesting user.

          Raises:
              ParseError: if the ``file`` part or the ``format`` field is
                  missing from the request.

          Returns:
              An empty ``201 Created`` response.
          """
          if 'file' not in request.FILES:
              raise ParseError('Empty content')
          project = get_object_or_404(Project, pk=self.kwargs['project_id'])
          # Without this guard a missing field surfaces as an unhandled
          # KeyError from request.data['format'] rather than a client error.
          if 'format' not in request.data:
              raise ParseError('Missing format')
          handler = project.get_upload_handler(request.data['format'])
          handler.handle_uploaded_file(request.FILES['file'], self.request.user)
          return Response(status=status.HTTP_201_CREATED)
  class TextDownloadAPI(APIView):
      """Export a project's data in the format named by the ``q`` query parameter."""

      permission_classes = (IsAuthenticated, IsProjectUser, IsAdminUser)

      def get(self, request, *args, **kwargs):
          """Return whatever the project's handler for ``q`` renders.

          ``q`` is passed to ``get_upload_handler`` as-is; it is ``None`` when
          the parameter is absent. Responds with 404 when the project does
          not exist.
          """
          project_pk = self.kwargs['project_id']
          requested_format = request.query_params.get('q')
          project = get_object_or_404(Project, pk=project_pk)
          return project.get_upload_handler(requested_format).render()
  class FileHandler(object):
      """Base class for per-format import/export handlers bound to a project.

      Subclasses override ``handle_uploaded_file``, ``parse`` and ``render``
      and set ``annotation_serializer``; the ``save_*`` helpers validate and
      persist through the serializers.
      """

      # Serializer class used by save_annotation(); set by subclasses.
      annotation_serializer = None

      def __init__(self, project):
          self.project = project

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Import ``file`` on behalf of ``user``; must be overridden."""
          raise NotImplementedError()

      def parse(self, file):
          """Turn ``file`` into records; must be overridden."""
          raise NotImplementedError()

      def render(self):
          """Build the download response; must be overridden."""
          raise NotImplementedError()

      def save_doc(self, data):
          """Validate ``data`` and save it as a document of this project.

          Validation failures propagate as the serializer's exception.
          """
          doc_serializer = DocumentSerializer(data=data)
          doc_serializer.is_valid(raise_exception=True)
          return doc_serializer.save(project=self.project)

      def save_label(self, data):
          """Save a label for this project, reusing a matching one if present.

          The first label of this project matching ``data`` (or ``None``) is
          handed to the serializer as the instance to save into.
          """
          existing = Label.objects.filter(project=self.project, **data).first()
          label_serializer = LabelSerializer(existing, data=data)
          label_serializer.is_valid(raise_exception=True)
          return label_serializer.save(project=self.project)

      def save_annotation(self, data, doc, user):
          """Validate ``data`` and save it as ``user``'s annotation on ``doc``."""
          annotation_serializer = self.annotation_serializer(data=data)
          annotation_serializer.is_valid(raise_exception=True)
          return annotation_serializer.save(document=doc, user=user)
  class CoNLLHandler(FileHandler):
      """Import sequence-labeling data from a CoNLL-style file.

      Each non-blank line holds one word and its tag, separated by a single
      tab character. A blank line ends the current sentence.

      For example (the column separator is a tab):
      ```
      EU      B-ORG
      rejects O
      German  B-MISC
      call    O
      to      O
      boycott O
      British B-MISC
      lamb    O
      .       O

      Peter     B-PER
      Blackburn I-PER
      ...
      ```
      """

      annotation_serializer = SequenceAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Create one document per sentence and one annotation per token.

          The document text is the sentence's words joined by single spaces;
          annotation offsets are character positions within that text.
          """
          for words, tags in self.parse(file):
              doc = self.save_doc({'text': ' '.join(words)})
              start_offset = 0
              for word, tag in zip(words, tags):
                  label = self.save_label({'text': extract_label(tag)})
                  end_offset = start_offset + len(word)
                  data = {'start_offset': start_offset,
                          'end_offset': end_offset,
                          'label': label.id}
                  self.save_annotation(data, doc, user)
                  # +1 skips the single space inserted by ' '.join() above.
                  start_offset = end_offset + 1

      def parse(self, file):
          """Yield ``(words, tags)`` for each sentence found in ``file``.

          ``file`` is iterated line by line and every line is decoded as
          UTF-8, so it must yield bytes. Leading and repeated blank lines are
          ignored, so an empty sentence is never yielded.

          Raises:
              FileParseException: if a non-blank line does not split into
                  exactly two tab-separated fields.
          """
          words, tags = [], []
          for line_num, raw in enumerate(file, start=1):
              line = raw.decode('utf-8').strip()
              if line:
                  try:
                      word, tag = line.split('\t')
                  except ValueError:
                      raise FileParseException(line_num=line_num, line=line)
                  words.append(word)
                  tags.append(tag)
              elif words:
                  # A blank line closes the sentence, but only if one is
                  # open. This matches the end-of-file guard below.
                  yield words, tags
                  words, tags = [], []
          if words:
              # The final sentence may not be followed by a blank line.
              yield words, tags

      def render(self):
          """Reject downloads: CoNLL export is not supported."""
          raise ValidationError("This project type doesn't support CoNLL format.")
  class PlainTextHandler(FileHandler):
      """Import plain text, one document per line.

      Example file:
      ```
      EU rejects German call to boycott British lamb.
      President Obama is speaking at the White House.
      ...
      ```
      """

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save every parsed line as a document of the project."""
          for line_text in self.parse(file):
              self.save_doc({'text': line_text})

      def parse(self, file):
          """Yield each line of ``file``, decoded as UTF-8 and stripped."""
          text_stream = io.TextIOWrapper(file, encoding='utf-8')
          for raw_line in text_stream:
              yield raw_line.strip()

      def render(self):
          """Reject downloads: plain text export is not supported."""
          raise ValidationError("You cannot download plain text. Please specify csv or json.")
  class CSVHandler(FileHandler):
      """Import a two-column CSV file whose first row is a header.

      Both the header and every data row must have exactly two columns.

      Example file:
      ```
      text, label
      "EU rejects German call to boycott British lamb.",Politics
      "President Obama is speaking at the White House.",Politics
      "He lives in Newark, Ohio.",Other
      ...
      ```
      """

      def parse(self, file):
          """Yield ``(text, label)`` for every data row of ``file``.

          ``file`` is read as UTF-8 text. The first row is treated as the
          header and is not yielded.

          Raises:
              FileParseException: if a data row, or the header, does not have
                  exactly two columns. ``line_num`` counts the header as 1.
          """
          reader = csv.reader(io.TextIOWrapper(file, encoding='utf-8'))
          header = next(reader, None)
          if header is None:
              # Empty file: nothing to yield.
              return
          for line_num, row in enumerate(reader, start=2):
              if not (len(row) == len(header) == 2):
                  raise FileParseException(line_num=line_num, line=row)
              text, label = row
              yield text, label

      def render(self):
          """Build the download response; must be overridden."""
          raise NotImplementedError()
  class CSVClassificationHandler(CSVHandler):
      """CSV import/export where the second column is a label."""

      annotation_serializer = DocumentAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save a document, its label and one annotation for every row."""
          for text, label_text in self.parse(file):
              doc = self.save_doc({'text': text})
              saved_label = self.save_label({'text': label_text})
              self.save_annotation({'label': saved_label.id}, doc, user)

      def render(self):
          """Return a CSV attachment with one row per annotation.

          Columns are ``id``, ``text``, ``label`` and ``user``. The file name
          is the lower-cased project name with whitespace runs replaced by
          underscores.
          """
          serializer = DocumentSerializer(self.project.documents.all(), many=True)
          filename = '_'.join(self.project.name.lower().split())

          response = HttpResponse(content_type='text/csv')
          response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename)

          writer = csv.writer(response)
          writer.writerow(['id', 'text', 'label', 'user'])
          writer.writerows(
              [doc['id'], doc['text'], annotation['label'], annotation['user']]
              for doc in serializer.data
              for annotation in doc['annotations']
          )
          return response
  class CSVSeq2seqHandler(CSVHandler):
      """CSV import/export where the second column is annotation text."""

      annotation_serializer = Seq2seqAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save a document and one text annotation for every row."""
          for text, target_text in self.parse(file):
              doc = self.save_doc({'text': text})
              self.save_annotation({'text': target_text}, doc, user)

      def render(self):
          """Return a CSV attachment with one row per annotation.

          Columns are ``id``, ``text``, ``label`` and ``user``; the ``label``
          column carries each annotation's ``text``. The file name is the
          lower-cased project name with whitespace runs replaced by
          underscores.
          """
          serializer = DocumentSerializer(self.project.documents.all(), many=True)
          filename = '_'.join(self.project.name.lower().split())

          response = HttpResponse(content_type='text/csv')
          response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename)

          writer = csv.writer(response)
          writer.writerow(['id', 'text', 'label', 'user'])
          writer.writerows(
              [doc['id'], doc['text'], annotation['text'], annotation['user']]
              for doc in serializer.data
              for annotation in doc['annotations']
          )
          return response
  class JsonHandler(FileHandler):
      """Import/export JSON Lines: one JSON value per line.

      Example file:
      ```
      {"text": "example1"}
      {"text": "example2"}
      ...
      ```
      """

      def parse(self, file):
          """Yield the decoded JSON value of each line of ``file``.

          Raises:
              FileParseException: if a line is not valid JSON.
          """
          for line_num, raw in enumerate(file, start=1):
              try:
                  record = json.loads(raw)
              except json.decoder.JSONDecodeError:
                  raise FileParseException(line_num=line_num, line=raw)
              yield record

      def render(self):
          """Return a ``.jsonl`` attachment with one serialized document per line.

          Non-ASCII characters are written as-is rather than escaped. The
          file name is the lower-cased project name with whitespace runs
          replaced by underscores.
          """
          serializer = DocumentSerializer(self.project.documents.all(), many=True)
          filename = '_'.join(self.project.name.lower().split())

          response = HttpResponse(content_type='application/json')
          response['Content-Disposition'] = 'attachment; filename="{}.jsonl"'.format(filename)

          for document in serializer.data:
              line = json.dumps(document, ensure_ascii=False)
              response.write(line + '\n')
          return response
  class JsonClassificationHandler(JsonHandler):
      """Import JSON Lines for text classification.

      Every record carries the document ``text`` and a list of ``labels``:
      ```
      {"text": "Python is awesome!", "labels": ["positive"]}
      ...
      ```
      """

      annotation_serializer = DocumentAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save each record as a document plus one annotation per label."""
          for record in self.parse(file):
              doc = self.save_doc(record)
              for label_text in record['labels']:
                  saved_label = self.save_label({'text': label_text})
                  self.save_annotation({'label': saved_label.id}, doc, user)
  class JsonLabelingHandler(JsonHandler):
      """Import JSON Lines for sequence labeling.

      Every record carries the document ``text`` and a list of ``entities``,
      each a ``[start_offset, end_offset, label]`` triple:
      ```
      {"text": "Python is awesome!", "entities": [[0, 6, "Product"]]}
      ...
      ```
      """

      annotation_serializer = SequenceAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save each record as a document plus one annotation per entity."""
          for data in self.parse(file):
              doc = self.save_doc(data)
              for start_offset, end_offset, label_text in data['entities']:
                  label = self.save_label({'text': label_text})
                  # Use a separate name for the payload so the parsed record
                  # in ``data`` is not rebound while its entities are iterated.
                  annotation = {'label': label.id,
                                'start_offset': start_offset,
                                'end_offset': end_offset}
                  self.save_annotation(annotation, doc, user)
  class JsonSeq2seqHandler(JsonHandler):
      """Import JSON Lines for seq2seq projects.

      Every record carries the document ``text`` and a list of ``labels``,
      each saved as the text of one annotation:
      ```
      {"text": "Hello, World!", "labels": ["こんにちは、世界!"]}
      ...
      ```
      """

      annotation_serializer = Seq2seqAnnotationSerializer

      @transaction.atomic
      def handle_uploaded_file(self, file, user):
          """Save each record as a document plus one text annotation per label."""
          for record in self.parse(file):
              doc = self.save_doc(record)
              for target_text in record['labels']:
                  self.save_annotation({'text': target_text}, doc, user)