diff --git a/generated/model.py b/generated/model.py index 70275f95..5e6aac4c 100644 --- a/generated/model.py +++ b/generated/model.py @@ -566,3 +566,60 @@ class PaginatedRuleList(BaseModel): next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"]) previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"]) results: List[Rule] + + +class MLPipeline(BaseModel): + """ + An ML pipeline attached to a detector. Contains the pipeline configuration and model binary key. + """ + + id: str = Field(..., description="A unique ID for this pipeline.") + pipeline_config: Optional[str] = Field(None, description="Pipeline configuration string.") + cached_vizlogic_key: Optional[str] = Field(None, description="S3 key of the trained model binary.") + is_active_pipeline: bool = Field(False, description="Whether this is the active (production) pipeline.") + is_edge_pipeline: bool = Field(False, description="Whether this is an edge pipeline.") + is_unclear_pipeline: bool = Field(False, description="Whether this is an unclear-handling pipeline.") + is_oodd_pipeline: bool = Field(False, description="Whether this is an out-of-distribution detection pipeline.") + is_enabled: bool = Field(True, description="Whether this pipeline is enabled.") + created_at: Optional[datetime] = None + trained_at: Optional[datetime] = None + + +class PaginatedMLPipelineList(BaseModel): + count: int = Field(..., examples=[123]) + next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"]) + previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"]) + results: List[MLPipeline] + + +class PrimingGroup(BaseModel): + """ + A PrimingGroup seeds new detectors with a pre-trained model binary so they start with a head start. + """ + + id: str = Field(..., description="A unique ID for this priming group.") + name: str = Field(..., description="A short, descriptive name for the priming group.") + canonical_query: Optional[str] = Field(None, description="Optional canonical query describing this priming group.") + active_pipeline_config: Optional[str] = Field(None, description="Pipeline config used by detectors in this group.") + active_pipeline_base_mlbinary_key: Optional[str] = Field( + None, description="S3 key of the model binary that seeds new detectors in this group." + ) + is_global: bool = Field( + False, + description="If True, this priming group is visible to all users regardless of ownership.", + ) + disable_shadow_pipelines: bool = Field( + False, + description=( + "If True, new detectors in this group will not receive default shadow pipelines, " + "guaranteeing the primed model stays active." + ), + ) + created_at: Optional[datetime] = None + + +class PaginatedPrimingGroupList(BaseModel): + count: int = Field(..., examples=[123]) + next: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=4"]) + previous: Optional[AnyUrl] = Field(None, examples=["http://api.example.org/accounts/?page=2"]) + results: List[PrimingGroup] diff --git a/src/groundlight/client.py b/src/groundlight/client.py index c18ac01f..eac568b8 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -7,6 +7,7 @@ from io import BufferedReader, BytesIO from typing import Any, Callable, List, Optional, Tuple, Union +import requests from groundlight_openapi_client import Configuration from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi from groundlight_openapi_client.api.detectors_api import DetectorsApi @@ -33,9 +34,11 @@ Detector, DetectorGroup, ImageQuery, + MLPipeline, ModeEnum, PaginatedDetectorList, PaginatedImageQueryList, + PrimingGroup, ) from urllib3.exceptions import InsecureRequestWarning from urllib3.util.retry import Retry @@ -1852,3 +1855,160 @@ def create_bounding_box_detector( # noqa: PLR0913 # pylint: disable=too-many-ar detector_creation_input.mode_configuration = mode_config obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) return Detector.parse_obj(obj.to_dict()) + + # --------------------------------------------------------------------------- + # ML Pipeline methods + # --------------------------------------------------------------------------- + + def list_detector_pipelines(self, detector: Union[str, Detector]) -> List[MLPipeline]: + """ + Lists all ML pipelines associated with a given detector. + + Each detector can have multiple pipelines (active, edge, shadow, etc.). This method returns + all of them, which is useful when selecting a source pipeline to seed a new PrimingGroup. + + **Example usage**:: + + gl = Groundlight() + detector = gl.get_detector("det_abc123") + pipelines = gl.list_detector_pipelines(detector) + for p in pipelines: + if p.is_active_pipeline: + print(f"Active pipeline: {p.id}, config={p.pipeline_config}") + + :param detector: A Detector object or detector ID string. + :return: A list of MLPipeline objects for this detector. + """ + detector_id = detector.id if isinstance(detector, Detector) else detector + url = f"{self.api_client.configuration.host}/v1/detectors/{detector_id}/pipelines" + response = requests.get( + url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl + ) + if response.status_code == 404: + raise NotFoundError(f"Detector '{detector_id}' not found.") + response.raise_for_status() + data = response.json() + return [MLPipeline(**item) for item in data.get("results", [])] + + # --------------------------------------------------------------------------- + # PrimingGroup methods + # --------------------------------------------------------------------------- + + def list_priming_groups(self) -> List[PrimingGroup]: + """ + Lists all PrimingGroups owned by the authenticated user's account. + + PrimingGroups let you seed new detectors with a pre-trained model so they start with a + meaningful head start instead of a blank slate. + + **Example usage**:: + + gl = Groundlight() + groups = gl.list_priming_groups() + for g in groups: + print(f"{g.name}: {g.id}") + + :return: A list of PrimingGroup objects. + """ + url = f"{self.api_client.configuration.host}/v1/priming-groups" + response = requests.get( + url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl + ) + response.raise_for_status() + data = response.json() + return [PrimingGroup(**item) for item in data.get("results", [])] + + def create_priming_group( + self, + name: str, + source_ml_pipeline_id: str, + canonical_query: Optional[str] = None, + disable_shadow_pipelines: bool = False, + ) -> PrimingGroup: + """ + Creates a new PrimingGroup seeded from an existing ML pipeline. + + The trained model binary from the source pipeline is copied into the new PrimingGroup. + Detectors subsequently created with this PrimingGroup's ID will start with that model + already loaded, bypassing the cold-start period. + + **Example usage**:: + + gl = Groundlight() + detector = gl.get_detector("det_abc123") + pipelines = gl.list_detector_pipelines(detector) + active = next(p for p in pipelines if p.is_active_pipeline) + + priming_group = gl.create_priming_group( + name="door-detector-primer", + source_ml_pipeline_id=active.id, + canonical_query="Is the door open?", + disable_shadow_pipelines=True, + ) + print(f"Created priming group: {priming_group.id}") + + :param name: A short, descriptive name for the priming group. + :param source_ml_pipeline_id: The ID of an MLPipeline whose trained model will seed this group. + The pipeline must belong to a detector in your account. + :param canonical_query: An optional description of the visual question this group answers. + :param disable_shadow_pipelines: If True, detectors created in this group will not receive + default shadow pipelines, ensuring the primed model stays active. + :return: The created PrimingGroup object. + """ + url = f"{self.api_client.configuration.host}/v1/priming-groups" + payload: dict = { + "name": name, + "source_ml_pipeline_id": source_ml_pipeline_id, + "disable_shadow_pipelines": disable_shadow_pipelines, + } + if canonical_query is not None: + payload["canonical_query"] = canonical_query + response = requests.post( + url, json=payload, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl + ) + response.raise_for_status() + return PrimingGroup(**response.json()) + + def get_priming_group(self, priming_group_id: str) -> PrimingGroup: + """ + Retrieves a PrimingGroup by ID. + + **Example usage**:: + + gl = Groundlight() + pg = gl.get_priming_group("pgp_abc123") + print(f"Priming group name: {pg.name}") + + :param priming_group_id: The ID of the PrimingGroup to retrieve. + :return: The PrimingGroup object. + """ + url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}" + response = requests.get( + url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl + ) + if response.status_code == 404: + raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.") + response.raise_for_status() + return PrimingGroup(**response.json()) + + def delete_priming_group(self, priming_group_id: str) -> None: + """ + Deletes (soft-deletes) a PrimingGroup owned by the authenticated user. + + This does not delete any detectors that were created using this priming group — + it only removes the priming group itself. Detectors already created remain unaffected. + + **Example usage**:: + + gl = Groundlight() + gl.delete_priming_group("pgp_abc123") + + :param priming_group_id: The ID of the PrimingGroup to delete. + """ + url = f"{self.api_client.configuration.host}/v1/priming-groups/{priming_group_id}" + response = requests.delete( + url, headers=self.api_client._headers(), verify=self.api_client.configuration.verify_ssl + ) + if response.status_code == 404: + raise NotFoundError(f"PrimingGroup '{priming_group_id}' not found.") + response.raise_for_status()