Browse Source

Support project clone in backend

pull/2204/head
Hironsan 2 years ago
parent
commit
09c6fe1c80
4 changed files with 112 additions and 3 deletions
  1. 49
      backend/projects/models.py
  2. 48
      backend/projects/tests/test_project.py
  3. 3
      backend/projects/urls.py
  4. 15
      backend/projects/views/project.py

49
backend/projects/models.py

@ -1,4 +1,5 @@
import abc import abc
import uuid
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
@ -60,6 +61,54 @@ class Project(PolymorphicModel):
def is_text_project(self) -> bool: def is_text_project(self) -> bool:
return False return False
def clone(self) -> "Project":
"""Clone the project.
See https://docs.djangoproject.com/en/4.2/topics/db/queries/#copying-model-instances
Returns:
The cloned project.
"""
project = Project.objects.get(pk=self.pk)
project.pk = None
project.id = None
project._state.adding = True
project.save()
def bulk_clone(queryset: models.QuerySet, field_initializers: dict = None):
"""Clone the queryset.
Args:
queryset: The queryset to clone.
field_initializers: The field initializers.
"""
if field_initializers is None:
field_initializers = {}
items = []
for item in queryset:
item.id = None
item.pk = None
for field, value_or_callable in field_initializers.items():
if callable(value_or_callable):
value_or_callable = value_or_callable()
setattr(item, field, value_or_callable)
item.project = project
item._state.adding = True
items.append(item)
queryset.model.objects.bulk_create(items)
bulk_clone(self.role_mappings.all())
bulk_clone(self.tags.all())
# clone examples
bulk_clone(self.examples.all(), field_initializers={"uuid": uuid.uuid4})
# clone label types
bulk_clone(self.categorytype_set.all())
bulk_clone(self.spantype_set.all())
bulk_clone(self.relationtype_set.all())
return project
def __str__(self): def __str__(self):
return self.name return self.name

48
backend/projects/tests/test_project.py

@ -1,9 +1,12 @@
from django.conf import settings from django.conf import settings
from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from api.tests.utils import CRUDMixin from api.tests.utils import CRUDMixin
from projects.models import Member
from examples.tests.utils import make_doc
from label_types.tests.utils import make_label
from projects.models import DOCUMENT_CLASSIFICATION, Member, Project
from projects.tests.utils import prepare_project from projects.tests.utils import prepare_project
from roles.tests.utils import create_default_roles from roles.tests.utils import create_default_roles
from users.tests.utils import make_user from users.tests.utils import make_user
@ -124,3 +127,46 @@ class TestProjectDetailAPI(CRUDMixin):
def test_denies_non_member_to_delete_project(self): def test_denies_non_member_to_delete_project(self):
self.assert_delete(self.non_member, status.HTTP_403_FORBIDDEN) self.assert_delete(self.non_member, status.HTTP_403_FORBIDDEN)
class TestProjectModel(TestCase):
def setUp(self):
self.project = prepare_project().item
def test_clone_project(self):
project = self.project.clone()
self.assertNotEqual(project.id, self.project.id)
self.assertEqual(project.name, self.project.name)
self.assertEqual(project.role_mappings.count(), self.project.role_mappings.count())
class TestCloneProject(CRUDMixin):
task = DOCUMENT_CLASSIFICATION
view_name = "annotation_list"
@classmethod
def setUpTestData(cls):
project = prepare_project(task=DOCUMENT_CLASSIFICATION)
cls.project = project.item
cls.user = project.admin
make_doc(cls.project)
cls.category_type = make_label(cls.project)
cls.url = reverse(viewname="clone_project", args=[cls.project.id])
def test_clone_project(self):
response = self.assert_create(self.user, status.HTTP_201_CREATED)
project = Project.objects.get(id=response.data["id"])
# assert project
self.assertNotEqual(project.id, self.project.id)
self.assertEqual(project.name, self.project.name)
# assert category type
category_type = project.categorytype_set.first()
self.assertEqual(category_type.text, self.category_type.text)
# assert example
example = self.project.examples.first()
cloned_example = project.examples.first()
self.assertEqual(example.text, cloned_example.text)

3
backend/projects/urls.py

@ -1,7 +1,7 @@
from django.urls import path from django.urls import path
from .views.member import MemberDetail, MemberList, MyRole from .views.member import MemberDetail, MemberList, MyRole
from .views.project import ProjectDetail, ProjectList
from .views.project import CloneProject, ProjectDetail, ProjectList
from .views.tag import TagDetail, TagList from .views.tag import TagDetail, TagList
urlpatterns = [ urlpatterns = [
@ -11,5 +11,6 @@ urlpatterns = [
path(route="projects/<int:project_id>/tags", view=TagList.as_view(), name="tag_list"), path(route="projects/<int:project_id>/tags", view=TagList.as_view(), name="tag_list"),
path(route="projects/<int:project_id>/tags/<int:tag_id>", view=TagDetail.as_view(), name="tag_detail"), path(route="projects/<int:project_id>/tags/<int:tag_id>", view=TagDetail.as_view(), name="tag_detail"),
path(route="projects/<int:project_id>/members", view=MemberList.as_view(), name="member_list"), path(route="projects/<int:project_id>/members", view=MemberList.as_view(), name="member_list"),
path(route="projects/<int:project_id>/clone", view=CloneProject.as_view(), name="clone_project"),
path(route="projects/<int:project_id>/members/<int:member_id>", view=MemberDetail.as_view(), name="member_detail"), path(route="projects/<int:project_id>/members/<int:member_id>", view=MemberDetail.as_view(), name="member_detail"),
] ]

15
backend/projects/views/project.py

@ -1,6 +1,8 @@
from django.conf import settings from django.conf import settings
from django.db import transaction
from django.shortcuts import get_object_or_404
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, generics, status
from rest_framework import filters, generics, status, views
from rest_framework.permissions import IsAdminUser, IsAuthenticated from rest_framework.permissions import IsAdminUser, IsAuthenticated
from rest_framework.response import Response from rest_framework.response import Response
@ -52,3 +54,14 @@ class ProjectDetail(generics.RetrieveUpdateDestroyAPIView):
serializer_class = ProjectPolymorphicSerializer serializer_class = ProjectPolymorphicSerializer
lookup_url_kwarg = "project_id" lookup_url_kwarg = "project_id"
permission_classes = [IsAuthenticated & (IsProjectAdmin | IsProjectStaffAndReadOnly)] permission_classes = [IsAuthenticated & (IsProjectAdmin | IsProjectStaffAndReadOnly)]
class CloneProject(views.APIView):
permission_classes = [IsAuthenticated & IsProjectAdmin]
@transaction.atomic
def post(self, request, *args, **kwargs):
project = get_object_or_404(Project, pk=self.kwargs["project_id"])
cloned_project = project.clone()
serializer = ProjectPolymorphicSerializer(cloned_project)
return Response(serializer.data, status=status.HTTP_201_CREATED)
Loading…
Cancel
Save