diff --git a/app/db.sqlite3 b/app/db.sqlite3 index 9416a62c..12b9af51 100644 Binary files a/app/db.sqlite3 and b/app/db.sqlite3 differ diff --git a/app/server/api.py b/app/server/api.py index cb536123..348a38ac 100644 --- a/app/server/api.py +++ b/app/server/api.py @@ -21,19 +21,12 @@ class ProjectViewSet(viewsets.ModelViewSet): permission_classes = (IsAuthenticated, IsAdminUserAndWriteOnly) def get_queryset(self): - user = self.request.user - queryset = self.queryset.filter(users__id__contains=user.id) - - return queryset + return self.request.user.projects @action(methods=['get'], detail=True) def progress(self, request, pk=None): project = self.get_object() - docs = project.get_documents(is_null=True) - total = project.documents.count() - remaining = docs.count() - - return Response({'total': total, 'remaining': remaining}) + return Response(project.get_progress()) class ProjectLabelsAPI(generics.ListCreateAPIView): diff --git a/app/server/models.py b/app/server/models.py index f2c1af38..87116300 100644 --- a/app/server/models.py +++ b/app/server/models.py @@ -22,7 +22,7 @@ class Project(models.Model): guideline = models.TextField() created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) - users = models.ManyToManyField(User) + users = models.ManyToManyField(User, related_name='projects') project_type = models.CharField(max_length=30, choices=PROJECT_CHOICES) def get_absolute_url(self): @@ -31,6 +31,12 @@ class Project(models.Model): def is_type_of(self, project_type): return project_type == self.project_type + def get_progress(self): + docs = self.get_documents(is_null=True) + total = self.documents.count() + remaining = docs.count() + return {'total': total, 'remaining': remaining} + @property def image(self): if self.is_type_of(self.DOCUMENT_CLASSIFICATION): diff --git a/app/server/tests/test_models.py b/app/server/tests/test_models.py index afc0879d..8efce6dc 100644 --- a/app/server/tests/test_models.py +++ b/app/server/tests/test_models.py @@ -10,6 +10,12 @@ class TestProject(TestCase): project = mixer.blend('server.Project') project.is_type_of(project.project_type) + def test_get_progress(self): + project = mixer.blend('server.Project') + res = project.get_progress() + self.assertEqual(res['total'], 0) + self.assertEqual(res['remaining'], 0) + class TestLabel(TestCase):