You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

207 lines
7.3 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. import csv
  2. import json
  3. from io import TextIOWrapper
  4. import itertools as it
  5. import logging
  6. from django.contrib.auth.views import LoginView as BaseLoginView
  7. from django.urls import reverse
  8. from django.http import HttpResponse, HttpResponseRedirect
  9. from django.shortcuts import get_object_or_404
  10. from django.views import View
  11. from django.views.generic import TemplateView, CreateView
  12. from django.views.generic.list import ListView
  13. from django.contrib.auth.mixins import LoginRequiredMixin
  14. from django.contrib import messages
  15. from .permissions import SuperUserMixin
  16. from .forms import ProjectForm
  17. from .models import Document, Project
  18. from app import settings
  # Module-level logger, named after this module so records can be filtered per module.
  logger = logging.getLogger(name=__name__)
  class IndexView(TemplateView):
      """Public landing page; renders the static ``index.html`` template."""

      template_name = 'index.html'
  class ProjectView(LoginRequiredMixin, TemplateView):
      """Render a project page using the template chosen by the project itself."""

      def get_template_names(self):
          # Unknown project ids from the URL produce a 404.
          selected = get_object_or_404(Project, pk=self.kwargs['project_id'])
          template = selected.get_template_name()
          return [template]
  class ProjectsView(LoginRequiredMixin, CreateView):
      """Projects page: a login-protected ``CreateView`` driven by ``ProjectForm``."""

      template_name = 'projects.html'
      form_class = ProjectForm
  class DatasetView(SuperUserMixin, LoginRequiredMixin, ListView):
      """Admin list of a project's documents, paginated five per page."""

      template_name = 'admin/dataset.html'
      paginate_by = 5

      def get_queryset(self):
          # 404 when the project id in the URL does not exist.
          owner = get_object_or_404(Project, pk=self.kwargs['project_id'])
          queryset = owner.documents.all()
          return queryset
  class LabelView(SuperUserMixin, LoginRequiredMixin, TemplateView):
      """Admin page rendered from ``admin/label.html``."""

      template_name = 'admin/label.html'
  class StatsView(SuperUserMixin, LoginRequiredMixin, TemplateView):
      """Admin page rendered from ``admin/stats.html``."""

      template_name = 'admin/stats.html'
  class GuidelineView(SuperUserMixin, LoginRequiredMixin, TemplateView):
      """Admin page rendered from ``admin/guideline.html``."""

      template_name = 'admin/guideline.html'
  class DataUpload(SuperUserMixin, LoginRequiredMixin, TemplateView):
      """Upload page that imports documents into a project from CSV or JSON-lines files."""

      template_name = 'admin/dataset_upload.html'

      class ImportFileError(Exception):
          """Raised for user-correctable problems with an uploaded file.

          ``message`` is shown to the user as an error flash message by ``post``.
          """

          def __init__(self, message):
              # Forward to Exception so str(e) and e.args also carry the message.
              super().__init__(message)
              self.message = message

      def extract_metadata_csv(self, row, text_col, header_without_text):
          """Return a JSON object string pairing non-text header titles with this row's other values."""
          vals_without_text = [val for i, val in enumerate(row) if i != text_col]
          return json.dumps(dict(zip(header_without_text, vals_without_text)))

      def csv_to_documents(self, project, file, text_key='text'):
          """Lazily build unsaved ``Document`` objects from a CSV upload.

          If the first row contains ``text_key`` it is treated as a header and that
          column becomes the document text; the other columns go into ``metadata``.
          A single-column file without that header is treated as headerless text.

          Args:
              project: project every created document is attached to.
              file: binary file object of the upload.
              text_key: header title identifying the text column.

          Returns:
              An iterable of unsaved ``Document`` objects, or ``[]`` when the
              file is empty or starts with a blank line.

          Raises:
              DataUpload.ImportFileError: a multi-column file lacks a ``text_key`` column.
          """
          # The csv module requires newline=''. 'utf-8-sig' strips a leading BOM
          # (common in spreadsheet exports) so the "text" header still matches;
          # BOM-less UTF-8 decodes exactly as plain 'utf-8'.
          form_data = TextIOWrapper(file, encoding='utf-8-sig', newline='')
          reader = csv.reader(form_data)
          maybe_header = next(reader, None)  # None for an empty upload instead of StopIteration
          if not maybe_header:
              return []
          if text_key in maybe_header:
              text_col = maybe_header.index(text_key)
          elif len(maybe_header) == 1:
              # Headerless single-column file: the first row is data, so put it back.
              reader = it.chain([maybe_header], reader)
              text_col = 0
          else:
              raise DataUpload.ImportFileError("CSV file must have either a title with \"text\" column or have only one column ")
          header_without_text = [title for i, title in enumerate(maybe_header) if i != text_col]
          return (
              Document(
                  text=row[text_col],
                  metadata=self.extract_metadata_csv(row, text_col, header_without_text),
                  project=project,
              )
              for row in reader
              if row  # csv.reader yields [] for blank lines; skip them instead of IndexError
          )

      def extract_metadata_json(self, entry, text_key):
          """Return ``entry`` minus ``text_key`` as a JSON string (``entry`` is not mutated)."""
          copy = entry.copy()
          del copy[text_key]
          return json.dumps(copy)

      def json_to_documents(self, project, file, text_key='text'):
          """Lazily build unsaved ``Document`` objects from a JSON-lines upload.

          Each non-blank line must be a JSON object containing ``text_key``; the
          remaining keys are stored as ``metadata``.
          """
          # Skip blank lines (e.g. trailing empty lines), which json.loads rejects.
          parsed_entries = (json.loads(line) for line in file if line.strip())
          return (
              Document(text=entry[text_key], metadata=self.extract_metadata_json(entry, text_key), project=project)
              for entry in parsed_entries
          )

      def post(self, request, *args, **kwargs):
          """Import the uploaded file into the project in batches, then redirect.

          On success, redirects to the project's dataset page. On failure, adds an
          error flash message and redirects back to the upload page.
          """
          project = get_object_or_404(Project, pk=kwargs.get('project_id'))
          try:
              # Looked up inside the try so a missing field is reported, not a 500.
              import_format = request.POST.get('format')
              file = request.FILES['file'].file
              if import_format == 'csv':
                  documents = self.csv_to_documents(project, file)
              elif import_format == 'json':
                  documents = self.json_to_documents(project, file)
              else:
                  raise DataUpload.ImportFileError('Unsupported import format: {}'.format(import_format))
              # islice must consume a single iterator; over a list it would restart at
              # the front on every pass and never terminate.
              documents = iter(documents)
              batch_size = settings.IMPORT_BATCH_SIZE
              while True:
                  batch = list(it.islice(documents, batch_size))
                  if not batch:
                      break
                  Document.objects.bulk_create(batch, batch_size=batch_size)
              return HttpResponseRedirect(reverse('dataset', args=[project.id]))
          except DataUpload.ImportFileError as e:
              messages.add_message(request, messages.ERROR, e.message)
              return HttpResponseRedirect(reverse('upload', args=[project.id]))
          except Exception:
              logger.exception('Document import failed for project %s', project.id)
              messages.add_message(request, messages.ERROR, 'Something went wrong')
              return HttpResponseRedirect(reverse('upload', args=[project.id]))
  class DataDownload(SuperUserMixin, LoginRequiredMixin, TemplateView):
      """Admin page rendered from ``admin/dataset_download.html``."""

      template_name = 'admin/dataset_download.html'
  class DataDownloadFile(SuperUserMixin, LoginRequiredMixin, View):
      """Send a project's documents as a downloadable CSV or JSON-lines attachment.

      The documents come from ``project.get_documents(is_null=False).distinct()``
      (presumably the annotated ones — confirm against ``Project.get_documents``).
      """

      def get(self, request, *args, **kwargs):
          """Return the export selected by ``?format=csv|json``.

          Unknown formats and export failures add an error message and redirect
          back to the download page.
          """
          project_id = self.kwargs['project_id']
          project = get_object_or_404(Project, pk=project_id)
          docs = project.get_documents(is_null=False).distinct()
          export_format = request.GET.get('format')
          # e.g. "My Project" -> "my_project"
          filename = '_'.join(project.name.lower().split())
          exporters = {'csv': self.get_csv, 'json': self.get_json}
          exporter = exporters.get(export_format)
          if exporter is None:
              # Previously this fell through to an unbound `response` and was
              # reported as an unexpected crash; handle it explicitly instead.
              messages.add_message(request, messages.ERROR, 'Unsupported export format')
              return HttpResponseRedirect(reverse('download', args=[project.id]))
          try:
              return exporter(filename, docs)
          except Exception:
              logger.exception('Document export failed for project %s', project.id)
              messages.add_message(request, messages.ERROR, "Something went wrong")
              return HttpResponseRedirect(reverse('download', args=[project.id]))

      def get_csv(self, filename, docs):
          """Build a CSV attachment from the rows each document's ``to_csv()`` returns."""
          response = HttpResponse(content_type='text/csv')
          response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename)
          writer = csv.writer(response)
          for d in docs:
              writer.writerows(d.to_csv())
          return response

      def get_json(self, filename, docs):
          """Build a JSON-lines attachment: one ``to_json()`` object per line."""
          response = HttpResponse(content_type='text/json')
          response['Content-Disposition'] = 'attachment; filename="{}.json"'.format(filename)
          for d in docs:
              dump = json.dumps(d.to_json(), ensure_ascii=False)
              response.write(dump + '\n')  # write each json object end with a newline
          return response
  class LoginView(BaseLoginView):
      """Login page that advertises whichever social-auth backends are configured."""

      template_name = 'login.html'
      redirect_authenticated_user = True
      # Each flag is truthy when the matching backend setting is non-empty.
      extra_context = {
          'github_login': bool(settings.SOCIAL_AUTH_GITHUB_KEY),
          'aad_login': bool(settings.SOCIAL_AUTH_AZUREAD_TENANT_OAUTH2_TENANT_ID),
      }

      def get_context_data(self, **kwargs):
          """Add ``social_login_enabled``: true if any ``*_login`` context flag is set."""
          context = super().get_context_data(**kwargs)
          login_flags = [flag for name, flag in context.items() if name.endswith('_login')]
          context['social_login_enabled'] = any(login_flags)
          return context
  class DemoTextClassification(TemplateView):
      """Public demo page for text classification."""

      template_name = 'demo/demo_text_classification.html'
  class DemoNamedEntityRecognition(TemplateView):
      """Public demo page for named-entity recognition."""

      template_name = 'demo/demo_named_entity.html'
  class DemoTranslation(TemplateView):
      """Public demo page for translation."""

      template_name = 'demo/demo_translation.html'