From 4a4749d4e018bdbc55a9946645ef67031937958e Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 7 Feb 2022 10:03:03 +0900 Subject: [PATCH 1/2] Add mypy as a dependency --- Pipfile | 4 +++ Pipfile.lock | 33 +++++++++++++++++++- backend/api/tests/utils.py | 4 ++- backend/auto_labeling/pipeline/labels.py | 6 ++-- backend/data_export/pipeline/factories.py | 4 +-- backend/data_export/pipeline/repositories.py | 21 +++++++------ backend/data_export/pipeline/writers.py | 8 ++--- backend/data_import/pipeline/builders.py | 6 ++-- backend/data_import/pipeline/cleaners.py | 6 ++-- backend/data_import/pipeline/labels.py | 4 +-- backend/data_import/pipeline/parsers.py | 8 ++--- backend/data_import/pipeline/readers.py | 2 +- backend/data_import/pipeline/writers.py | 6 ++-- backend/projects/permissions.py | 4 +-- backend/projects/serializers.py | 6 ++-- backend/projects/tests/utils.py | 2 +- pyproject.toml | 6 ++++ 17 files changed, 87 insertions(+), 43 deletions(-) diff --git a/Pipfile b/Pipfile index 560d5eca..e7b36c78 100644 --- a/Pipfile +++ b/Pipfile @@ -19,6 +19,9 @@ watchdog = "*" argh = "*" black = "*" pyproject-flake8 = "*" +types-chardet = "*" +types-requests = "*" +types-waitress = "*" [packages] django = "~=3.2" @@ -63,6 +66,7 @@ python_version = "3.8" isort = "isort api -c --skip migrations" flake8 = "pflake8 --filename \"*.py\" --extend-exclude \"*/migrations\"" black = "black --check ." +mypy = "mypy ." wait_for_db = "python manage.py wait_for_db" test = "python manage.py test --pattern=\"test*.py\"" migrate = "python manage.py migrate" diff --git a/Pipfile.lock b/Pipfile.lock index a45c5785..edf101ce 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "1ec1af8bf3969c0798488c46364a48661b7789b66b2682c23f858ca760a4c89c" + "sha256": "b73f92aa76c8d2a6fdc0927b1853c46c27ae9508d5c2173b6e79823fcd1bf634" }, "pipfile-spec": 6, "requires": { @@ -1680,6 +1680,37 @@ ], "version": "==2.0.0" }, + "types-chardet": { + "hashes": [ + "sha256:519850a12ab0009f3ec5bdca35ce1c0de4eb4a67a2110aa206386e6219b3ecd8", + "sha256:8990a86d4c7cfa6c6c5889fc49e456e477851e75b5adb396d42ae106d0ae02ea" + ], + "index": "pypi", + "version": "==4.0.3" + }, + "types-requests": { + "hashes": [ + "sha256:8ec9f5f84adc6f579f53943312c28a84e87dc70201b54f7c4fbc7d22ecfa8a3e", + "sha256:c2f4e4754d07ca0a88fd8a89bbc6c8a9f90fb441f9c9b572fd5c484f04817486" + ], + "index": "pypi", + "version": "==2.27.8" + }, + "types-urllib3": { + "hashes": [ + "sha256:4a54f6274ab1c80968115634a55fb9341a699492b95e32104a7c513db9fe02e9", + "sha256:abd2d4857837482b1834b4817f0587678dcc531dbc9abe4cde4da28cef3f522c" + ], + "version": "==1.26.9" + }, + "types-waitress": { + "hashes": [ + "sha256:d7843d13487effb0e0774ec294f42ca63ed9f74a9296b47e4e290ddb21a05292", + "sha256:fdd57199a5a7b5b3e65973feb137964bd750cdb1af4f7cc7c9d6053342f86ff2" + ], + "index": "pypi", + "version": "==2.0.6" + }, "typing-extensions": { "hashes": [ "sha256:4ca091dea149f945ec56afb48dae714f21e8692ef22a395223bcd328961b6a0e", diff --git a/backend/api/tests/utils.py b/backend/api/tests/utils.py index eec47832..1073ef0e 100644 --- a/backend/api/tests/utils.py +++ b/backend/api/tests/utils.py @@ -1,10 +1,12 @@ +from typing import Any, Dict + from rest_framework import status from rest_framework.test import APITestCase class CRUDMixin(APITestCase): url = "" - data = {} + data: Dict[str, Any] = {} def assert_fetch(self, user=None, expected=status.HTTP_403_FORBIDDEN): if user: diff --git a/backend/auto_labeling/pipeline/labels.py b/backend/auto_labeling/pipeline/labels.py index c2efecab..1010bb2b 100644 --- a/backend/auto_labeling/pipeline/labels.py +++ b/backend/auto_labeling/pipeline/labels.py @@ -6,13 +6,13 @@ from django.contrib.auth.models import User from projects.models import Project from examples.models import Example -from label_types.models import CategoryType, SpanType +from label_types.models import CategoryType, LabelType, SpanType from labels.models import Label, Category, Span, TextLabel class LabelCollection(abc.ABC): - label_type = None - model = None + label_type: LabelType = None + model: Label = None def __init__(self, labels): self.labels = labels diff --git a/backend/data_export/pipeline/factories.py b/backend/data_export/pipeline/factories.py index 9b387d77..fa963ab7 100644 --- a/backend/data_export/pipeline/factories.py +++ b/backend/data_export/pipeline/factories.py @@ -11,7 +11,7 @@ from projects.models import ( from . import catalog, repositories, writers -def create_repository(project) -> repositories.BaseRepository: +def create_repository(project): mapping = { DOCUMENT_CLASSIFICATION: repositories.TextClassificationRepository, SEQUENCE_LABELING: repositories.SequenceLabelingRepository, @@ -22,7 +22,7 @@ def create_repository(project) -> repositories.BaseRepository: } if project.project_type not in mapping: ValueError(f"Invalid project type: {project.project_type}") - repository = mapping.get(project.project_type)(project) + repository = mapping[project.project_type](project) return repository diff --git a/backend/data_export/pipeline/repositories.py b/backend/data_export/pipeline/repositories.py index fd031e62..d5c924bd 100644 --- a/backend/data_export/pipeline/repositories.py +++ b/backend/data_export/pipeline/repositories.py @@ -1,12 +1,14 @@ import abc import itertools from collections import defaultdict -from typing import Dict, Iterator, List +from typing import Dict, Iterator, List, Tuple, Union from projects.models import Project from examples.models import Example from .data import Record +SpanType = Tuple[int, int, str] + class BaseRepository(abc.ABC): def __init__(self, project: Project): @@ -144,16 +146,15 @@ class IntentDetectionSlotFillingRepository(TextRepository): ) def label_per_user(self, doc) -> Dict: - category_per_user = defaultdict(list) - span_per_user = defaultdict(list) - label_per_user = defaultdict(dict) + category_per_user: Dict[str, List[str]] = defaultdict(list) + span_per_user: Dict[str, List[SpanType]] = defaultdict(list) + label_per_user: Dict[str, Dict[str, Union[List[str], List[SpanType]]]] = defaultdict(dict) for a in doc.categories.all(): category_per_user[a.user.username].append(a.label.text) for a in doc.spans.all(): - label = (a.start_offset, a.end_offset, a.label.text) - span_per_user[a.user.username].append(label) - for user, label in category_per_user.items(): - label_per_user[user]["cats"] = label - for user, label in span_per_user.items(): - label_per_user[user]["entities"] = label + span_per_user[a.user.username].append((a.start_offset, a.end_offset, a.label.text)) + for user, cats in category_per_user.items(): + label_per_user[user]["cats"] = cats + for user, span in span_per_user.items(): + label_per_user[user]["entities"] = span return label_per_user diff --git a/backend/data_export/pipeline/writers.py b/backend/data_export/pipeline/writers.py index 72469f56..f1fbc284 100644 --- a/backend/data_export/pipeline/writers.py +++ b/backend/data_export/pipeline/writers.py @@ -58,9 +58,9 @@ class CsvWriter(BaseWriter): def write(self, records: Iterator[Record]) -> str: writers = {} file_handlers = set() - records = list(records) - header = self.create_header(records) - for record in records: + record_list = list(records) + header = self.create_header(record_list) + for record in record_list: filename = os.path.join(self.tmpdir, f"{record.user}.{self.extension}") if filename not in writers: f = open(filename, mode="a", encoding="utf-8") @@ -82,7 +82,7 @@ class CsvWriter(BaseWriter): def create_line(self, record) -> Dict: return {"id": record.id, "data": record.data, "label": "#".join(sorted(record.label)), **record.metadata} - def create_header(self, records: List[Record]) -> Iterable[str]: + def create_header(self, records: List[Record]) -> List[str]: header = ["id", "data", "label"] header += sorted(set(itertools.chain(*[r.metadata.keys() for r in records]))) return header diff --git a/backend/data_import/pipeline/builders.py b/backend/data_import/pipeline/builders.py index 242d48d8..62a8a237 100644 --- a/backend/data_import/pipeline/builders.py +++ b/backend/data_import/pipeline/builders.py @@ -1,6 +1,6 @@ import abc from logging import getLogger -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, Dict, List, Optional, Type from pydantic import ValidationError @@ -10,7 +10,6 @@ from .labels import Label from .readers import Builder, Record logger = getLogger(__name__) -T = TypeVar("T") class PlainBuilder(Builder): @@ -34,7 +33,8 @@ def build_data(row: Dict[Any, Any], name: str, data_class: Type[BaseData], filen class Column(abc.ABC): - def __init__(self, name: str, value_class: Type[T]): + # Todo: need to redesign. + def __init__(self, name: str, value_class: Any): self.name = name self.value_class = value_class diff --git a/backend/data_import/pipeline/cleaners.py b/backend/data_import/pipeline/cleaners.py index 320b3bfe..93b7e48a 100644 --- a/backend/data_import/pipeline/cleaners.py +++ b/backend/data_import/pipeline/cleaners.py @@ -1,7 +1,7 @@ from typing import List from projects.models import Project -from .labels import CategoryLabel, Label, SpanLabel +from .labels import Label, SpanLabel class Cleaner: @@ -21,7 +21,7 @@ class SpanCleaner(Cleaner): super().__init__(project) self.allow_overlapping = getattr(project, "allow_overlapping", False) - def clean(self, labels: List[SpanLabel]) -> List[SpanLabel]: + def clean(self, labels: List[SpanLabel]) -> List[SpanLabel]: # type: ignore if self.allow_overlapping: return labels @@ -44,7 +44,7 @@ class CategoryCleaner(Cleaner): super().__init__(project) self.exclusive = getattr(project, "single_class_classification", False) - def clean(self, labels: List[CategoryLabel]) -> List[CategoryLabel]: + def clean(self, labels: List[Label]) -> List[Label]: if self.exclusive: return labels[:1] else: diff --git a/backend/data_import/pipeline/labels.py b/backend/data_import/pipeline/labels.py index 595c2d82..84f6b395 100644 --- a/backend/data_import/pipeline/labels.py +++ b/backend/data_import/pipeline/labels.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from pydantic import BaseModel, validator @@ -68,7 +68,7 @@ class CategoryLabel(Label): class SpanLabel(Label): - label: Union[str, int] + label: str start_offset: int end_offset: int diff --git a/backend/data_import/pipeline/parsers.py b/backend/data_import/pipeline/parsers.py index 05840158..c9ac6ee6 100644 --- a/backend/data_import/pipeline/parsers.py +++ b/backend/data_import/pipeline/parsers.py @@ -47,7 +47,7 @@ def detect_encoding(filename: str, buffer_size: int = io.DEFAULT_BUFFER_SIZE) -> if detector.done: break if detector.done: - return detector.result["encoding"] + return detector.result["encoding"] or "utf-8" else: return "utf-8" @@ -164,7 +164,7 @@ class JSONParser(Parser): def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding - self._errors = [] + self._errors: List[FileParseException] = [] def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: encoding = decide_encoding(filename, self.encoding) @@ -191,7 +191,7 @@ class JSONLParser(Parser): def __init__(self, encoding: str = DEFAULT_ENCODING, **kwargs): self.encoding = encoding - self._errors = [] + self._errors: List[FileParseException] = [] def parse(self, filename: str) -> Iterator[Dict[Any, Any]]: reader = LineReader(filename, self.encoding) @@ -288,7 +288,7 @@ class CoNLLParser(Parser): self.encoding = encoding self.delimiter = delimiter mapping = {"IOB2": IOB2, "IOE2": IOE2, "IOBES": IOBES, "BILOU": BILOU} - self._errors = [] + self._errors: List[FileParseException] = [] if scheme in mapping: self.scheme = mapping[scheme] else: diff --git a/backend/data_import/pipeline/readers.py b/backend/data_import/pipeline/readers.py index ecf54b82..79fe69aa 100644 --- a/backend/data_import/pipeline/readers.py +++ b/backend/data_import/pipeline/readers.py @@ -100,7 +100,7 @@ class Reader(BaseReader): self.filenames = filenames self.parser = parser self.builder = builder - self._errors = [] + self._errors: List[FileParseException] = [] def __iter__(self) -> Iterator[Record]: for filename in self.filenames: diff --git a/backend/data_import/pipeline/writers.py b/backend/data_import/pipeline/writers.py index a339c325..f5c4f191 100644 --- a/backend/data_import/pipeline/writers.py +++ b/backend/data_import/pipeline/writers.py @@ -9,7 +9,7 @@ from projects.models import Project from examples.models import Example from label_types.models import CategoryType, SpanType from .exceptions import FileParseException -from .readers import BaseReader +from .readers import BaseReader, Record class Writer(abc.ABC): @@ -33,7 +33,7 @@ def group_by_class(instances): class Examples: def __init__(self, buffer_size: int = settings.IMPORT_BATCH_SIZE): self.buffer_size = buffer_size - self.buffer = [] + self.buffer: List[Record] = [] def __len__(self): return len(self.buffer) @@ -85,7 +85,7 @@ class Examples: class BulkWriter(Writer): def __init__(self, batch_size: int): self.examples = Examples(batch_size) - self._errors = [] + self._errors: List[FileParseException] = [] def save(self, reader: BaseReader, project: Project, user, cleaner): it = iter(reader) diff --git a/backend/projects/permissions.py b/backend/projects/permissions.py index 6e43f6c6..4e655c33 100644 --- a/backend/projects/permissions.py +++ b/backend/projects/permissions.py @@ -50,5 +50,5 @@ class IsAnnotationApprover(RolePermission): role_name = settings.ROLE_ANNOTATION_APPROVER -IsProjectMember = IsAnnotator | IsAnnotationApprover | IsProjectAdmin -IsProjectStaffAndReadOnly = IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly +IsProjectMember = IsAnnotator | IsAnnotationApprover | IsProjectAdmin # type: ignore +IsProjectStaffAndReadOnly = IsAnnotatorAndReadOnly | IsAnnotationApproverAndReadOnly # type: ignore diff --git a/backend/projects/serializers.py b/backend/projects/serializers.py index 3484dd11..fe3df7c4 100644 --- a/backend/projects/serializers.py +++ b/backend/projects/serializers.py @@ -49,7 +49,7 @@ class ProjectSerializer(serializers.ModelSerializer): class Meta: model = Project - fields = ( + fields = [ "id", "name", "description", @@ -66,7 +66,7 @@ class ProjectSerializer(serializers.ModelSerializer): "can_define_category", "can_define_span", "tags", - ) + ] read_only_fields = ( "updated_at", "is_text_project", @@ -86,7 +86,7 @@ class TextClassificationProjectSerializer(ProjectSerializer): class SequenceLabelingProjectSerializer(ProjectSerializer): class Meta(ProjectSerializer.Meta): model = SequenceLabelingProject - fields = ProjectSerializer.Meta.fields + ("allow_overlapping", "grapheme_mode") + fields = ProjectSerializer.Meta.fields + ["allow_overlapping", "grapheme_mode"] class Seq2seqProjectSerializer(ProjectSerializer): diff --git a/backend/projects/tests/utils.py b/backend/projects/tests/utils.py index a39a5bb0..dcde935d 100644 --- a/backend/projects/tests/utils.py +++ b/backend/projects/tests/utils.py @@ -49,7 +49,7 @@ def assign_user_to_role(project_member, project, role_name): return mapping -def make_project(task: str, users: List[str], roles: List[str] = None, collaborative_annotation=False, **kwargs): +def make_project(task: str, users: List[str], roles: List[str], collaborative_annotation=False, **kwargs): create_default_roles() # create users. diff --git a/pyproject.toml b/pyproject.toml index a30a0a7c..d630de1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,9 @@ max-line-length = 120 max-complexity = 18 ignore = "E203,E266,W503," filename = "backend/*" + +[tool.mypy] +python_version = "3.8" +ignore_missing_imports = true +show_error_codes = true +exclude = "(migrations)|(app/settings.py)" From cf6c48632dc8a893290566fd769889cdfc4342cb Mon Sep 17 00:00:00 2001 From: Hironsan Date: Mon, 7 Feb 2022 10:03:58 +0900 Subject: [PATCH 2/2] Update workflow to include mypy --- .github/workflows/ci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f428407..46d4da8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,6 +35,9 @@ jobs: run: | pipenv run black working-directory: ./backend + - name: mypy + run: | + pipenv run mypy - name: Run tests run: | pipenv run test