diff --git a/pygitguardian/client.py b/pygitguardian/client.py index d847a77f..37742660 100644 --- a/pygitguardian/client.py +++ b/pygitguardian/client.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, Iterator, List, Optional, Union, cast import requests from requests import Response, Session, codes @@ -39,6 +39,7 @@ SecretScanPreferences, ServerMetadata, ) +from .remediation_models import ListSourcesResponse from .sca_models import ( ComputeSCAFilesResult, SCAScanAllOutput, @@ -268,11 +269,23 @@ def request( return response def _url_from_endpoint(self, endpoint: str, version: Optional[str]) -> str: + if endpoint.startswith(self.base_uri): + return endpoint + if version: endpoint = urllib.parse.urljoin(version + "/", endpoint) return urllib.parse.urljoin(self.base_uri + "/", endpoint) + def get_all_pages(self, endpoint: str, **kwargs: Any) -> Iterator[Response]: + last_response = self.request("get", endpoint, **kwargs) + yield last_response + while "next" in last_response.links: + last_response = self.request( + "get", last_response.links["next"]["url"], **kwargs + ) + yield last_response + @property def app_version(self) -> Optional[str]: global VERSIONS @@ -790,3 +803,26 @@ def scan_diff( result = load_detail(response) result.status_code = response.status_code return result + + def list_sources( + self, + params: Union[Dict[str, Any], None] = None, + extra_headers: Optional[Dict[str, str]] = None, + ) -> Union[Detail, ListSourcesResponse]: + result: Union[Detail, ListSourcesResponse] + responses = [] + try: + for response in self.get_all_pages( + endpoint="sources", params=params, extra_headers=extra_headers + ): + if not is_ok(response): + return load_detail(response) + responses.append(response) + except requests.exceptions.ReadTimeout: + result = Detail("The request timed out.") + result.status_code = 504 + else: + sources = [source for response in responses for source in response.json()] + result = ListSourcesResponse.from_dict({"sources": sources}) + result.status_code = responses[-1].status_code + return result diff --git a/pygitguardian/remediation_models.py b/pygitguardian/remediation_models.py new file mode 100644 index 00000000..34825400 --- /dev/null +++ b/pygitguardian/remediation_models.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Type, Union, cast + +import marshmallow_dataclass +from marshmallow import fields + +from pygitguardian.models import Base, BaseSchema, FromDictMixin + + +class SourceHealth(Enum): + """Enum for the different health of a source.""" + + SAFE = "safe" + UNKNOWN = "unknown" + AT_RISK = "at_risk" + + +class SourceCriticality(Enum): + """Enum for the different criticality of a source.""" + + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "MEDIUM" + LOW = "LOW" + UNKNOWN = "unknown" + + +class SourceScanStatus(Enum): + """Enum for the different status of a source scan.""" + + PENDING = "pending" + RUNNING = "running" + CANCELED = "canceled" + FAILED = "failed" + TOO_LARGE = "too_large" + TIMEOUT = "timeout" + FINISHED = "finished" + + +@dataclass +class SourceScan(BaseSchema): + """Represents a scan of a source.""" + + date: str = fields.Date() + status: str = fields.Enum(SourceScanStatus) + failing_reason: Union[str, None] = fields.String(allow_none=True) + commits_scanned: int = fields.Int() + branches_scanned: int = fields.Int() + duration: str = fields.String() + + +@dataclass +class Source(BaseSchema): + """Represents a source.""" + + id: int = fields.Int() + url: str = fields.URL() + type: str = fields.String() + full_name: str = fields.String() + health: str = fields.Enum(SourceHealth) + default_branch: Union[str, None] = fields.String(allow_none=True) + default_branch_head: Union[str, None] = fields.String(allow_none=True) + open_incidents_count: int = fields.Int() + closed_incidents_count: int = fields.Int() + secret_incidents_breakdown: Dict[str, Any] = fields.Dict(keys=fields.Str()) + visibility: str = fields.String() + external_id: str = fields.String() + source_criticality: str = fields.Enum(SourceCriticality) + last_scan: Dict[str, Any] = fields.Dict(keys=fields.Str()) + + +class ListSourcesResponse(Base, FromDictMixin): + """Represents a list of sources.""" + + sources: List[Source] = fields.List(fields.Nested(Source)) + + +ListSourcesResponseSchema = cast( + Type[BaseSchema], + marshmallow_dataclass.class_schema(ListSourcesResponse, base_schema=BaseSchema), +) +ListSourcesResponse.SCHEMA = ListSourcesResponseSchema()