diff --git a/dlclivegui/assets/skeletons/__init__.py b/dlclivegui/assets/skeletons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlclivegui/config.py b/dlclivegui/config.py index 6d9e1de..8d56b12 100644 --- a/dlclivegui/config.py +++ b/dlclivegui/config.py @@ -12,6 +12,25 @@ TileLayout = Literal["auto", "2x2", "1x4", "4x1"] Precision = Literal["FP32", "FP16"] ModelType = Literal["pytorch", "tensorflow"] +BGR = tuple[int, int, int] # (B, G, R) color format + + +class SkeletonColorMode(str, Enum): + SOLID = "solid" + GRADIENT_KEYPOINTS = "gradient_keypoints" # use endpoint keypoint colors + + +class SkeletonStyle(BaseModel): + mode: SkeletonColorMode = SkeletonColorMode.SOLID + color: BGR = (0, 255, 255) # default if SOLID + thickness: int = 2 # base thickness in pixels + gradient_steps: int = 16 # segments per edge when gradient + scale_with_zoom: bool = True # scale thickness with (sx, sy) + + def effective_thickness(self, sx: float, sy: float) -> int: + if not self.scale_with_zoom: + return max(1, int(self.thickness)) + return max(1, int(round(self.thickness * min(sx, sy)))) class CameraSettings(BaseModel): @@ -301,12 +320,22 @@ class VisualizationSettings(BaseModel): colormap: str = "hot" bbox_color: tuple[int, int, int] = (0, 0, 255) + show_pose: bool = True + show_skeleton: bool = False + skeleton_style: SkeletonStyle = Field(default_factory=SkeletonStyle) + def get_bbox_color_bgr(self) -> tuple[int, int, int]: """Get bounding box color in BGR format""" if isinstance(self.bbox_color, (list, tuple)) and len(self.bbox_color) == 3: return tuple(int(c) for c in self.bbox_color) return (0, 0, 255) # default red + def get_skeleton_color_bgr(self) -> tuple[int, int, int]: + c = self.skeleton_style.color + if isinstance(c, (list, tuple)) and len(c) == 3: + return tuple(int(v) for v in c) + return (0, 255, 255) # default yellow + class RecordingSettings(BaseModel): enabled: bool = False @@ -315,6 +344,7 @@ class RecordingSettings(BaseModel): container: Literal["mp4", "avi", "mov"] = "mp4" codec: str = "libx264" crf: int = Field(default=23, ge=0, le=51) + with_overlays: bool = False def output_path(self) -> Path: """Return the absolute output path for recordings.""" diff --git a/dlclivegui/gui/main_window.py b/dlclivegui/gui/main_window.py index 0e3beb0..87b810c 100644 --- a/dlclivegui/gui/main_window.py +++ b/dlclivegui/gui/main_window.py @@ -42,6 +42,7 @@ QMainWindow, QMessageBox, QPushButton, + QScrollArea, QSizePolicy, QSpinBox, QStatusBar, @@ -62,6 +63,7 @@ VisualizationSettings, ) +from ..config import SkeletonColorMode, SkeletonStyle from ..processors.processor_utils import ( default_processors_dir, instantiate_from_scan, @@ -70,7 +72,8 @@ ) from ..services.dlc_processor import DLCLiveProcessor, PoseResult from ..services.multi_camera_controller import MultiCameraController, MultiFrameData, get_camera_id -from ..utils.display import BBoxColors, compute_tile_info, create_tiled_frame, draw_bbox, draw_pose +from ..utils import skeleton as skel +from ..utils.display import BBoxColors, compute_tile_info, create_tiled_frame, draw_bbox, draw_pose, keypoint_colors_bgr from ..utils.settings_store import DLCLiveGUISettingsStore, ModelPathStore from ..utils.stats import format_dlc_stats from ..utils.utils import FPSTracker @@ -164,6 +167,10 @@ def __init__(self, config: ApplicationSettings | None = None): self._p_cutoff = 0.6 self._colormap = "hot" self._bbox_color = (0, 0, 255) # BGR: red + ## Skeleton settings + self._skeleton: skel.Skeleton | None = None + self._skeleton_auto_disabled: bool = False + self._last_skeleton_disable_msg: str | None = None # Multi-camera state self._multi_camera_mode = False @@ -207,6 +214,8 @@ def __init__(self, config: ApplicationSettings | None = None): def resizeEvent(self, event): super().resizeEvent(event) + if hasattr(self, "controls_scroll"): + self._sync_controls_dock_min_width() if not self.multi_camera_controller.is_running(): self._show_logo_and_text() @@ -224,6 +233,34 @@ def _apply_theme(self, mode: AppStyle) -> None: def _load_icons(self): self.setWindowIcon(QIcon(LOGO)) + def _sync_controls_dock_min_width(self) -> None: + """Ensure the dock/scroll area is at least as wide as the controls content.""" + if not hasattr(self, "controls_scroll") or self.controls_scroll is None: + return + w = self.controls_scroll.widget() + if w is None: + return + + # Ensure layout has calculated its hints + w.adjustSize() + + # Minimum width needed by the controls content + content_w = w.minimumSizeHint().width() + if content_w <= 0: + content_w = w.sizeHint().width() + + # Reserve space for the vertical scrollbar (even if not currently visible) + vbar_w = self.controls_scroll.verticalScrollBar().sizeHint().width() + + # Account for scrollarea frame/margins + frame = self.controls_scroll.frameWidth() * 2 + + target = content_w + vbar_w + frame + + # Apply to both scroll area and dock (dock is what user resizes) + self.controls_scroll.setMinimumWidth(target) + self.controls_dock.setMinimumWidth(target) + def _setup_ui(self) -> None: # central = QWidget() # layout = QHBoxLayout(central) @@ -262,6 +299,7 @@ def _setup_ui(self) -> None: controls_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) controls_layout = QVBoxLayout(controls_widget) controls_layout.setContentsMargins(5, 5, 5, 5) + controls_layout.setSizeConstraint(QVBoxLayout.SizeConstraint.SetMinimumSize) controls_layout.addWidget(self._build_camera_group()) controls_layout.addWidget(self._build_dlc_group()) controls_layout.addWidget(self._build_recording_group()) @@ -287,7 +325,16 @@ def _setup_ui(self) -> None: ## Dock widget for controls self.controls_dock = QDockWidget("Controls", self) self.controls_dock.setObjectName("ControlsDock") # important for state saving - self.controls_dock.setWidget(controls_widget) + self.controls_scroll = QScrollArea() + controls_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Preferred) + self.controls_scroll.setWidget(controls_widget) + self.controls_scroll.setWidgetResizable(True) + self.controls_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.controls_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.controls_scroll.setFrameShape(QScrollArea.Shape.NoFrame) + self.controls_scroll.setSizeAdjustPolicy(QScrollArea.SizeAdjustPolicy.AdjustToContents) + # self.controls_dock.setWidget(controls_widget) + self.controls_dock.setWidget(self.controls_scroll) ### Dock features self.controls_dock.setFeatures( # must not be closable by user but visibility can be toggled from View -> Show controls @@ -310,6 +357,7 @@ def _setup_ui(self) -> None: self.setStatusBar(QStatusBar()) self._build_menus() + QTimer.singleShot(0, self._sync_controls_dock_min_width) # ensure dock is wide enough for controls after layout QTimer.singleShot(0, self._show_logo_and_text) def _build_stats_layout(self, stats_widget: QWidget) -> QGridLayout: @@ -664,8 +712,10 @@ def _build_recording_group(self) -> QGroupBox: return group def _build_viz_group(self) -> QGroupBox: + # Visualization settings group group = QGroupBox("Visualization") form = QFormLayout(group) + ## Pose overlay self.show_predictions_checkbox = QCheckBox("Display pose predictions") self.show_predictions_checkbox.setChecked(True) @@ -690,6 +740,47 @@ def _build_viz_group(self) -> QGroupBox: ) form.addRow(keypoints_settings) + ## Skeleton overlay + self.show_skeleton_checkbox = QCheckBox("Display skeleton") + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.setEnabled(False) + self.show_skeleton_checkbox.setToolTip( + "If enabled, draws connections between keypoints based on the model's skeleton definition.\n" + "Auto-disables if the model keypoints do not match the skeleton definition." + ) + + # Skeleton color mode / color + self.skeleton_color_combo = color_ui.make_skeleton_color_combo( + BBoxColors, # re-use PrimaryColors palette; you aliased it already + current_mode="solid", + current_color=(0, 255, 255), + include_icons=True, + tooltip="Select skeleton color, or Gradient to blend endpoint keypoint colors", + sizing=color_ui.ComboSizing(min_width=80, max_width=200), + ) + self.skeleton_color_combo.setEnabled(False) + + # Skeleton thickness + self.skeleton_thickness_spin = QSpinBox() + self.skeleton_thickness_spin.setRange(1, 20) + self.skeleton_thickness_spin.setValue(2) + self.skeleton_thickness_spin.setToolTip("Skeleton line thickness (scaled with zoom if enabled in style)") + self.skeleton_thickness_spin.setEnabled(False) + + # Layout like keypoints/bbox + skeleton_row = lyts.make_two_field_row( + "Skeleton:", + self.skeleton_color_combo, + None, + self.show_skeleton_checkbox, + key_width=120, + left_stretch=0, + right_stretch=0, + ) + form.addRow(skeleton_row) + form.addRow("Skeleton thickness:", self.skeleton_thickness_spin) + + ## Bounding box overlay & controls self.bbox_enabled_checkbox = QCheckBox("Show bounding box") self.bbox_enabled_checkbox.setChecked(False) @@ -740,7 +831,7 @@ def _build_viz_group(self) -> QGroupBox: self.bbox_y1_spin.setValue(100) bbox_layout.addWidget(self.bbox_y1_spin) - form.addRow("Coordinates", bbox_layout) + form.addRow("Box coordinates", bbox_layout) return group @@ -760,6 +851,10 @@ def _connect_signals(self) -> None: # Visualization settings ## Colormap change self.cmap_combo.currentIndexChanged.connect(self._on_colormap_changed) + ## Skeleton change + self.show_skeleton_checkbox.stateChanged.connect(self._on_show_skeleton_changed) + self.skeleton_color_combo.currentIndexChanged.connect(self._on_skeleton_style_changed) + self.skeleton_thickness_spin.valueChanged.connect(self._on_skeleton_style_changed) ## Connect bounding box controls self.bbox_enabled_checkbox.stateChanged.connect(self._on_bbox_changed) self.bbox_x0_spin.valueChanged.connect(self._on_bbox_changed) @@ -797,6 +892,45 @@ def _connect_signals(self) -> None: # ------------------------------------------------------------------ # Config + # ------------------------------------------------------------------ + def _apply_viz_settings_to_ui(self, viz: VisualizationSettings) -> None: + """Set UI state from VisualizationSettings (does not require skeleton to exist).""" + # Pose toggle + self.show_predictions_checkbox.blockSignals(True) + self.show_predictions_checkbox.setChecked(bool(viz.show_pose)) + self.show_predictions_checkbox.blockSignals(False) + + # Skeleton toggle (may remain disabled until skeleton exists) + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(bool(viz.show_skeleton)) + self.show_skeleton_checkbox.blockSignals(False) + + # Skeleton style controls (combo/spin) - set values even if disabled + if hasattr(self, "skeleton_color_combo"): + mode = viz.skeleton_style.mode.value # "solid" or "gradient_keypoints" + color = tuple(viz.skeleton_style.color) + color_ui.set_skeleton_combo_from_style(self.skeleton_color_combo, mode=mode, color=color) + + if hasattr(self, "skeleton_thickness_spin"): + self.skeleton_thickness_spin.blockSignals(True) + self.skeleton_thickness_spin.setValue(int(viz.skeleton_style.thickness)) + self.skeleton_thickness_spin.blockSignals(False) + + def _apply_viz_settings_to_skeleton(self, viz: VisualizationSettings) -> None: + """Apply VisualizationSettings onto the active runtime Skeleton, if present.""" + if self._skeleton is None: + return + + # Copy style fields + self._skeleton.style.mode = skel.SkeletonColorMode(viz.skeleton_style.mode.value) + self._skeleton.style.color = tuple(viz.skeleton_style.color) + self._skeleton.style.thickness = int(viz.skeleton_style.thickness) + self._skeleton.style.gradient_steps = int(viz.skeleton_style.gradient_steps) + self._skeleton.style.scale_with_zoom = bool(viz.skeleton_style.scale_with_zoom) + + # Enable/disable UI controls now that skeleton exists + self._sync_skeleton_controls_from_model() + def _apply_config(self, config: ApplicationSettings) -> None: # Update active cameras label self._update_active_cameras_label() @@ -813,6 +947,7 @@ def _apply_config(self, config: ApplicationSettings) -> None: self.output_directory_edit.setText(recording.directory) self.filename_edit.setText(recording.filename) self.container_combo.setCurrentText(recording.container) + self.record_with_overlays_checkbox.setChecked(recording.with_overlays) codec_index = self.codec_combo.findText(recording.codec) if codec_index >= 0: self.codec_combo.setCurrentIndex(codec_index) @@ -839,14 +974,20 @@ def _apply_config(self, config: ApplicationSettings) -> None: self.bbox_y1_spin.setValue(bbox.y1) # Set visualization settings from config + ## Keypoints viz = config.visualization self._p_cutoff = viz.p_cutoff self._colormap = viz.colormap + self._apply_viz_settings_to_ui(viz) if hasattr(self, "cmap_combo"): color_ui.set_cmap_combo_from_name(self.cmap_combo, self._colormap, fallback="viridis") self._bbox_color = viz.get_bbox_color_bgr() if hasattr(self, "bbox_color_combo"): color_ui.set_bbox_combo_from_bgr(self.bbox_color_combo, self._bbox_color) + ## Skeleton + if resolved_model_path.strip(): + self._configure_skeleton_for_model(resolved_model_path) + self._apply_viz_settings_to_skeleton(viz) # Update DLC camera list self._refresh_dlc_camera_list() @@ -909,6 +1050,7 @@ def _recording_settings_from_ui(self) -> RecordingSettings: container=self.container_combo.currentText().strip() or "mp4", codec=self.codec_combo.currentText().strip() or "libx264", crf=int(self.crf_spin.value()), + with_overlays=self.record_with_overlays_checkbox.isChecked(), ) def _bbox_settings_from_ui(self) -> BoundingBoxSettings: @@ -921,10 +1063,29 @@ def _bbox_settings_from_ui(self) -> BoundingBoxSettings: ) def _visualization_settings_from_ui(self) -> VisualizationSettings: + # Read skeleton mode+color from combo + mode_str, color = color_ui.get_skeleton_style_from_combo( + self.skeleton_color_combo, + fallback_mode="solid", + fallback_color=(0, 255, 255), + ) + + # Build SkeletonStyle (pydantic) + style = SkeletonStyle( + mode=SkeletonColorMode(mode_str), # or SkeletonColorMode.GRADIENT_KEYPOINTS if mode_str matches + color=tuple(color) if color else (0, 255, 255), + thickness=int(self.skeleton_thickness_spin.value()), + gradient_steps=getattr(self._skeleton.style, "gradient_steps", 16) if self._skeleton else 16, + scale_with_zoom=getattr(self._skeleton.style, "scale_with_zoom", True) if self._skeleton else True, + ) + return VisualizationSettings( p_cutoff=self._p_cutoff, colormap=self._colormap, bbox_color=self._bbox_color, + show_pose=self.show_predictions_checkbox.isChecked(), + show_skeleton=self.show_skeleton_checkbox.isChecked(), + skeleton_style=style, ) # ------------------------------------------------------------------ @@ -1018,6 +1179,7 @@ def _action_browse_model(self) -> None: return file_path = str(file_path) self.model_path_edit.setText(file_path) + self._configure_skeleton_for_model(file_path) # Persist model path + directory self._model_path_store.save_if_valid(file_path) @@ -1180,6 +1342,26 @@ def _on_bbox_color_changed(self, _index: int) -> None: if self._current_frame is not None: self._display_frame(self._current_frame, force=True) + def _on_show_skeleton_changed(self, _state: int) -> None: + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + if self._current_frame is not None: + self._display_frame(self._current_frame, force=True) + + def _on_skeleton_style_changed(self, _value: int = 0) -> None: + if self._skeleton is None: + return + + mode_str, color = color_ui.get_skeleton_style_from_combo(self.skeleton_color_combo) + self._skeleton.style.mode = skel.SkeletonColorMode(mode_str) + if self._skeleton.style.mode == skel.SkeletonColorMode.SOLID and color is not None: + self._skeleton.style.color = tuple(color) + + self._skeleton.style.thickness = int(self.skeleton_thickness_spin.value()) + + if self._current_frame is not None: + self._display_frame(self._current_frame, force=True) + # ------------------------------------------------------------------ # Multi-camera def _open_camera_config_dialog(self) -> None: @@ -1338,6 +1520,13 @@ def _render_overlays_for_recording(self, cam_id, frame): offset=offset, scale=scale, ) + self._draw_skeleton_on_frame( + output, + self._last_pose.pose, + offset=offset, + scale=scale, + ) + if self._bbox_enabled: output = draw_bbox( frame=output, @@ -1608,6 +1797,76 @@ def _stop_preview(self) -> None: self.camera_stats_label.setText("Camera idle") # self._show_logo_and_text() + def _sync_skeleton_controls_from_model(self) -> None: + """Enable and initialize skeleton UI controls from the current Skeleton.style.""" + enabled = self._skeleton is not None + + self.show_skeleton_checkbox.setEnabled(enabled) + self.skeleton_color_combo.setEnabled(enabled) + self.skeleton_thickness_spin.setEnabled(enabled) + + if not enabled: + return + + # Set thickness + self.skeleton_thickness_spin.blockSignals(True) + self.skeleton_thickness_spin.setValue(int(self._skeleton.style.thickness)) + self.skeleton_thickness_spin.blockSignals(False) + + # Set color/mode + mode = ( + "gradient_keypoints" if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS else "solid" + ) + color_ui.set_skeleton_combo_from_style( + self.skeleton_color_combo, + mode=mode, + color=self._skeleton.style.color, + ) + + if hasattr(self.skeleton_color_combo, "update_shrink_width"): + self.skeleton_color_combo.update_shrink_width() + + def _configure_skeleton_for_model(self, model_path: str) -> None: + """Select an appropriate skeleton definition for the currently configured model.""" + self._skeleton = None + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + + # Default: disable until we find a compatible skeleton + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.setEnabled(False) + # keep checked state but it won't be used unless enabled + + p = Path(model_path).expanduser() + + root = p if p.is_dir() else p.parent + cfg = root / "config.yaml" + if cfg.exists() and self._skeleton is None: + try: + sk = skel.load_dlc_skeleton(cfg) + except Exception as e: + logger.warning(f"Failed to load DLC skeleton from {cfg}: {e}") + sk = None + + if sk is not None: + self._skeleton = sk + self._sync_skeleton_controls_from_model() + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.setEnabled(True) + self.statusBar().showMessage("Skeleton available: DLC config.yaml", 3000) + + if self._skeleton is not None: + try: + viz = self._config.visualization + self._apply_viz_settings_to_skeleton(viz) + except Exception as e: + logger.warning(f"Failed to apply visualization settings to skeleton: {e}") + pass + return + + # None found + self.statusBar().showMessage("No skeleton definition available for this model.", 3000) + def _configure_dlc(self) -> bool: try: settings = self._dlc_settings_from_ui() @@ -1639,6 +1898,7 @@ def _configure_dlc(self) -> bool: self.statusBar().showMessage(f"Processor selection ignored (control disabled): {selected_key}", 3000) self._dlc.configure(settings, processor=processor) + self._configure_skeleton_for_model(settings.model_path) self._model_path_store.save_if_valid(settings.model_path) return True @@ -1864,6 +2124,18 @@ def _stop_inference(self, show_message: bool = True) -> None: self._last_processor_vid_recording = False self._auto_record_session_name = None + # Reset skeleton + self._skeleton = None + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + self.skeleton_color_combo.setEnabled(False) + self.skeleton_thickness_spin.setEnabled(False) + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.setEnabled(False) + self.show_skeleton_checkbox.blockSignals(False) + # Reset button appearance self.start_inference_button.setText("Start pose inference") self.start_inference_button.setStyleSheet("") @@ -1908,6 +2180,69 @@ def _on_dlc_error(self, message: str) -> None: self._stop_inference(show_message=False) self._show_error(message) + def _draw_skeleton_on_frame( + self, + overlay: np.ndarray, + pose: np.ndarray, + *, + offset: tuple[int, int], + scale: tuple[float, float], + allow_auto_disable: bool = True, + ) -> skel.SkeletonRenderStatus | None: + """Draw skeleton on overlay with correct style. Optionally auto-disables UI on mismatch.""" + if self._skeleton is None: + return None + if not self.show_skeleton_checkbox.isChecked(): + return None + if self._skeleton_auto_disabled: + return None + + pose_arr = np.asarray(pose) + + # Provide keypoint_colors iff gradient mode is active + kp_colors = None + if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS: + n_kpts = pose_arr.shape[1] if pose_arr.ndim == 3 else pose_arr.shape[0] + kp_colors = keypoint_colors_bgr(self._colormap, int(n_kpts)) + + try: + if pose_arr.ndim == 3: + status = self._skeleton.draw_many( + overlay, + pose_arr, + p_cutoff=self._p_cutoff, + offset=offset, + scale=scale, + keypoint_colors=kp_colors, + ) + else: + status = self._skeleton.draw( + overlay, + pose_arr, + p_cutoff=self._p_cutoff, + offset=offset, + scale=scale, + keypoint_colors=kp_colors, + ) + except Exception as e: + status = skel.SkeletonRenderStatus( + code=skel.SkeletonRenderCode.POSE_SHAPE_INVALID, + message=f"Skeleton rendering error: {e}", + ) + + if allow_auto_disable and status.should_disable: + self._skeleton_auto_disabled = True + msg = status.message or "Skeleton disabled due to mismatch." + if msg != self._last_skeleton_disable_msg: + self._last_skeleton_disable_msg = msg + self.statusBar().showMessage(f"Skeleton disabled: {msg}", 6000) + + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.blockSignals(False) + + return status + def _update_video_display(self, frame: np.ndarray) -> None: display_frame = frame @@ -1921,6 +2256,13 @@ def _update_video_display(self, frame: np.ndarray) -> None: scale=self._dlc_tile_scale, ) + self._draw_skeleton_on_frame( + display_frame, + self._last_pose.pose, + offset=self._dlc_tile_offset, + scale=self._dlc_tile_scale, + ) + if self._bbox_enabled: display_frame = draw_bbox( display_frame, diff --git a/dlclivegui/gui/misc/color_dropdowns.py b/dlclivegui/gui/misc/color_dropdowns.py index bb0f0ac..6f1351a 100644 --- a/dlclivegui/gui/misc/color_dropdowns.py +++ b/dlclivegui/gui/misc/color_dropdowns.py @@ -16,7 +16,7 @@ import numpy as np from PySide6.QtCore import Qt -from PySide6.QtGui import QColor, QIcon, QImage, QPainter, QPixmap +from PySide6.QtGui import QBrush, QColor, QIcon, QImage, QLinearGradient, QPainter, QPixmap from PySide6.QtWidgets import ( QComboBox, QSizePolicy, @@ -24,10 +24,31 @@ QStyleOptionComboBox, ) -BGR = tuple[int, int, int] +from dlclivegui.config import BGR + TEnum = TypeVar("TEnum") +def make_gradient_swatch_icon(*, width: int = 40, height: int = 16, border: int = 1) -> QIcon: + """Small gradient swatch icon for 'Gradient' mode.""" + pix = QPixmap(width, height) + pix.fill(Qt.transparent) + p = QPainter(pix) + + # border/background + p.fillRect(0, 0, width, height, Qt.black) + p.fillRect(border, border, width - 2 * border, height - 2 * border, Qt.white) + + # inner gradient (blue -> red, arbitrary but clearly "gradient") + grad = QLinearGradient(border + 1, 0, width - (border + 1), 0) + grad.setColorAt(0.0, QColor(0, 140, 255)) + grad.setColorAt(1.0, QColor(255, 80, 0)) + p.fillRect(border + 1, border + 1, width - 2 * (border + 1), height - 2 * (border + 1), QBrush(grad)) + + p.end() + return QIcon(pix) + + # ----------------------------------------------------------------------------- # Combo sizing: shrink to current selection + wide popup # ----------------------------------------------------------------------------- @@ -413,3 +434,114 @@ def get_cmap_name_from_combo(combo: QComboBox, *, fallback: str = "viridis") -> return data text = combo.currentText().strip() return text or fallback + + +# ----------------------------------------------------------------------------- +# Skeleton color combo helpers (enum-based + gradient) +# ----------------------------------------------------------------------------- +def get_skeleton_style_from_combo( + combo: QComboBox, + *, + fallback_mode: str = "solid", + fallback_color: BGR | None = None, +) -> tuple[str, BGR | None]: + data = combo.currentData() + if isinstance(data, dict): + mode = data.get("mode", fallback_mode) + color = data.get("color", fallback_color) + return mode, color + return fallback_mode, fallback_color + + +def make_skeleton_color_combo( + colors_enum: Iterable[TEnum], + *, + current_mode: str = "solid", + current_color: BGR | None = (0, 255, 255), + include_icons: bool = True, + tooltip: str = "Select skeleton line color or Gradient (from keypoints)", + sizing: ComboSizing | None = None, +) -> QComboBox: + combo = ShrinkCurrentWidePopupComboBox(sizing=sizing) if sizing is not None else QComboBox() + combo.setToolTip(tooltip) + populate_skeleton_color_combo( + combo, + colors_enum, + current_mode=current_mode, + current_color=current_color, + include_icons=include_icons, + ) + if isinstance(combo, ShrinkCurrentWidePopupComboBox): + combo.update_shrink_width() + return combo + + +def set_skeleton_combo_from_style(combo: QComboBox, *, mode: str, color: BGR | None) -> None: + """Select the best matching item.""" + # Gradient + if mode == "gradient_keypoints": + combo.findData({"mode": "gradient_keypoints"}) # may fail due to dict identity + # robust fallback: scan + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "gradient_keypoints": + combo.setCurrentIndex(i) + return + return + + # Solid with color + if color is not None: + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "solid" and tuple(d.get("color")) == tuple(color): + combo.setCurrentIndex(i) + return + + # Default: first solid entry + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "solid": + combo.setCurrentIndex(i) + return + + +def populate_skeleton_color_combo( + combo: QComboBox, + colors_enum: Iterable[TEnum], + *, + current_mode: str = "solid", + current_color: BGR | None = None, + include_icons: bool = True, + gradient_label: str = "Gradient (from keypoints)", +) -> None: + """ + Populate combo with: + - Gradient mode + - Solid colors (from enum values BGR) + ItemData is a dict with keys: + - mode: 'solid' or 'gradient_keypoints' + - color: optional BGR for solid + """ + combo.blockSignals(True) + combo.clear() + + # 1) Gradient option + if include_icons: + combo.addItem(make_gradient_swatch_icon(), gradient_label, {"mode": "gradient_keypoints"}) + else: + combo.addItem(gradient_label, {"mode": "gradient_keypoints"}) + + # 2) Solid colors + for enum_item in colors_enum: + bgr: BGR = enum_item.value + name = getattr(enum_item, "name", str(enum_item)).title() + data = {"mode": "solid", "color": bgr} + if include_icons: + combo.addItem(make_bgr_swatch_icon(bgr), name, data) + else: + combo.addItem(name, data) + + # Select current + set_skeleton_combo_from_style(combo, mode=current_mode, color=current_color) + + combo.blockSignals(False) diff --git a/dlclivegui/temp/yolo/__init__.py b/dlclivegui/temp/yolo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlclivegui/utils/display.py b/dlclivegui/utils/display.py index 0eac657..05c3408 100644 --- a/dlclivegui/utils/display.py +++ b/dlclivegui/utils/display.py @@ -8,7 +8,7 @@ import numpy as np -class BBoxColors(enum.Enum): +class PrimaryColors(enum.Enum): RED = (0, 0, 255) GREEN = (0, 255, 0) BLUE = (255, 0, 0) @@ -20,7 +20,26 @@ class BBoxColors(enum.Enum): @staticmethod def get_all_display_names() -> list[str]: - return [color.name.capitalize() for color in BBoxColors] + return [c.name.capitalize() for c in PrimaryColors] + + +BBoxColors = PrimaryColors + + +class SkeletonColors(enum.Enum): + GRADIENT = "gradient" # special mode + RED = PrimaryColors.RED.value + GREEN = PrimaryColors.GREEN.value + BLUE = PrimaryColors.BLUE.value + YELLOW = PrimaryColors.YELLOW.value + CYAN = PrimaryColors.CYAN.value + MAGENTA = PrimaryColors.MAGENTA.value + WHITE = PrimaryColors.WHITE.value + BLACK = PrimaryColors.BLACK.value + + @staticmethod + def get_all_display_names() -> list[str]: + return ["Gradient"] + [c.name.capitalize() for c in SkeletonColors if c != SkeletonColors.GRADIENT] def color_to_rgb(color_name: str) -> tuple[int, int, int]: @@ -31,6 +50,21 @@ def color_to_rgb(color_name: str) -> tuple[int, int, int]: raise ValueError(f"Unknown color name: {color_name}") from None +def keypoint_colors_bgr(colormap: str, num_keypoints: int) -> list[tuple[int, int, int]]: + """ + Return the exact BGR colors used by draw_keypoints() for a given Matplotlib colormap + and number of keypoints. + """ + cmap = plt.get_cmap(colormap) + colors: list[tuple[int, int, int]] = [] + for idx in range(num_keypoints): + t = idx / max(num_keypoints - 1, 1) + rgba = cmap(t) + bgr = (int(rgba[2] * 255), int(rgba[1] * 255), int(rgba[0] * 255)) + colors.append(bgr) + return colors + + def compute_tiling_geometry( frames: dict[str, np.ndarray], max_canvas: tuple[int, int] = (1200, 800), diff --git a/dlclivegui/utils/skeleton.py b/dlclivegui/utils/skeleton.py new file mode 100644 index 0000000..1bacbf2 --- /dev/null +++ b/dlclivegui/utils/skeleton.py @@ -0,0 +1,367 @@ +"""Skeleton definition, validation, and drawing utilities.""" + +# dlclivegui/utils/skeleton.py +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path + +import cv2 +import numpy as np +import yaml +from pydantic import BaseModel, Field, ValidationError, field_validator + +from dlclivegui.config import BGR, SkeletonColorMode, SkeletonStyle + + +# ############### # +# Status & code # +# ############### # +class SkeletonRenderCode(Enum): + OK = auto() + POSE_SHAPE_INVALID = auto() + KEYPOINT_COUNT_MISMATCH = auto() + + +@dataclass(frozen=True) +class SkeletonRenderStatus: + code: SkeletonRenderCode + message: str = "" + + @property + def rendered(self) -> bool: + return self.code == SkeletonRenderCode.OK + + @property + def should_disable(self) -> bool: + # GUI can switch off skeleton drawing if True + return self.code in { + SkeletonRenderCode.POSE_SHAPE_INVALID, + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + } + + +# ############ # +# Exceptions # +# ############ # + + +class SkeletonError(ValueError): + """Raised when a skeleton definition is invalid.""" + + +class SkeletonLoadError(Exception): + """High-level skeleton loading error (safe for GUI display).""" + + +class SkeletonValidationError(SkeletonLoadError): + """Schema or semantic validation error.""" + + +# ################## # +# Skeleton display # +# ################## # + + +class SkeletonStyleModel(BaseModel): + mode: SkeletonColorMode = SkeletonColorMode.SOLID + color: BGR = (0, 255, 255) # default if SOLID + thickness: int = Field(2, ge=1, description="Base thickness in pixels") + gradient_steps: int = Field(16, ge=2, description="Segments per edge when gradient") + scale_with_zoom: bool = True + + @field_validator("thickness") + @classmethod + def _thickness_positive(cls, v): + if v < 1: + raise ValueError("Thickness must be at least 1 pixel") + return v + + @field_validator("gradient_steps") + @classmethod + def _steps_positive(cls, v): + if v < 2: + raise ValueError("gradient_steps must be >= 2") + return v + + +# ############# # +# Skeleton IO # +# ############# # +class SkeletonModel(BaseModel): + """Validated skeleton definition (IO + schema).""" + + name: str | None = None + + keypoints: list[str] = Field(..., min_length=1, description="Ordered list of keypoint names") + + edges: list[tuple[int, int]] = Field( + default_factory=list, + description="List of (i, j) keypoint index pairs", + ) + + style: SkeletonStyleModel = Field(default_factory=SkeletonStyleModel) + default_color: BGR = (0, 255, 255) # used if style.color is None or in SOLID mode + edge_colors: dict[tuple[int, int], BGR] = Field(default_factory=dict) + + schema_version: int = 1 + + @field_validator("keypoints") + @classmethod + def validate_unique_keypoints(cls, v): + if len(set(v)) != len(v): + raise ValueError("Duplicate keypoint names detected") + return v + + @field_validator("edges") + @classmethod + def validate_edges(cls, edges, info): + keypoints = info.data.get("keypoints", []) + n = len(keypoints) + + for i, j in edges: + if i == j: + raise ValueError(f"Self-loop detected in edge ({i}, {j})") + if not (0 <= i < n and 0 <= j < n): + raise ValueError(f"Edge ({i}, {j}) out of range for {n} keypoints") + return edges + + +def _load_raw_skeleton_data(path: Path) -> dict: + if not path.exists(): + raise SkeletonLoadError(f"Skeleton file not found: {path}") + + if path.suffix in {".yaml", ".yml"}: + return yaml.safe_load(path.read_text()) + + if path.suffix == ".json": + return json.loads(path.read_text()) + + raise SkeletonLoadError(f"Unsupported file type: {path.suffix}") + + +def _format_pydantic_error(err: ValidationError) -> str: + lines = ["Invalid skeleton definition:"] + for e in err.errors(): + loc = " → ".join(map(str, e["loc"])) + msg = e["msg"] + lines.append(f"• {loc}: {msg}") + return "\n".join(lines) + + +def load_skeleton(path: Path) -> Skeleton: + try: + data = _load_raw_skeleton_data(path) + model = SkeletonModel.model_validate(data) + return Skeleton(model) + + except ValidationError as e: + raise SkeletonValidationError(_format_pydantic_error(e)) from None + + except Exception as e: + raise SkeletonLoadError(str(e)) from None + + +def save_skeleton(path: Path, model: SkeletonModel) -> None: + data = model.model_dump() + + if path.suffix in {".yaml", ".yml"}: + path.write_text(yaml.safe_dump(data, sort_keys=False)) + elif path.suffix == ".json": + path.write_text(json.dumps(data, indent=2)) + else: + raise SkeletonLoadError(f"Unsupported skeleton file type: {path.suffix}") + + +def load_dlc_skeleton(config_path: Path) -> Skeleton | None: + if not config_path.exists(): + raise SkeletonLoadError(f"DLC config not found: {config_path}") + + cfg = yaml.safe_load(config_path.read_text()) + + bodyparts = cfg.get("bodyparts") + if not bodyparts: + return None # No pose info + + edges = [] + + # Newer DLC format + if "skeleton" in cfg: + for a, b in cfg["skeleton"]: + edges.append((bodyparts.index(a), bodyparts.index(b))) + + # Older / alternative formats + elif "skeleton_edges" in cfg: + edges = [tuple(e) for e in cfg["skeleton_edges"]] + + if not edges: + return None + + model = SkeletonModel( + name=cfg.get("Task", "DeepLabCut"), + keypoints=bodyparts, + edges=edges, + ) + + return Skeleton(model) + + +class Skeleton: + """Runtime skeleton optimized for drawing.""" + + def __init__(self, model: SkeletonModel): + self.name = model.name + self.keypoints = model.keypoints + self.edges = model.edges + + self.style = SkeletonStyle( + mode=model.style.mode, + color=model.style.color, + thickness=model.style.thickness, + gradient_steps=model.style.gradient_steps, + scale_with_zoom=model.style.scale_with_zoom, + ) + self.default_color = model.default_color + self.edge_colors = model.edge_colors + + def check_pose_compat(self, pose: np.ndarray) -> SkeletonRenderStatus: + pose = np.asarray(pose) + + if pose.ndim != 2 or pose.shape[1] not in (2, 3): + return SkeletonRenderStatus( + SkeletonRenderCode.POSE_SHAPE_INVALID, + f"Pose must be (N,2) or (N,3); got shape={pose.shape}", + ) + + expected = len(self.keypoints) + got = pose.shape[0] + if got != expected: + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Skeleton expects {expected} keypoints, but pose has {got}.", + ) + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "") + + def _draw_gradient_edge( + self, + img: np.ndarray, + p1: tuple[int, int], + p2: tuple[int, int], + c1: BGR, + c2: BGR, + thickness: int, + steps: int, + ): + x1, y1 = p1 + x2, y2 = p2 + + for s in range(steps): + a0 = s / steps + a1 = (s + 1) / steps + xs0 = int(x1 + (x2 - x1) * a0) + ys0 = int(y1 + (y2 - y1) * a0) + xs1 = int(x1 + (x2 - x1) * a1) + ys1 = int(y1 + (y2 - y1) * a1) + + t = (s + 0.5) / steps + b = int(c1[0] + (c2[0] - c1[0]) * t) + g = int(c1[1] + (c2[1] - c1[1]) * t) + r = int(c1[2] + (c2[2] - c1[2]) * t) + + cv2.line(img, (xs0, ys0), (xs1, ys1), (b, g, r), thickness, lineType=cv2.LINE_AA) + + def draw( + self, + overlay: np.ndarray, + pose: np.ndarray, + p_cutoff: float, + offset: tuple[int, int], + scale: tuple[float, float], + *, + style: SkeletonStyle | None = None, + color_override: BGR | None = None, + keypoint_colors: list[BGR] | None = None, + ) -> SkeletonRenderStatus: + status = self.check_pose_compat(pose) + if not status.rendered: + return status + + st = style or self.style + ox, oy = offset + sx, sy = scale + th = st.effective_thickness(sx, sy) + + # if gradient mode, require keypoint_colors aligned with keypoint order + if st.mode == SkeletonColorMode.GRADIENT_KEYPOINTS: + if keypoint_colors is None or len(keypoint_colors) != len(self.keypoints): + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Gradient mode requires keypoint_colors of length {len(self.keypoints)}.", + ) + + for i, j in self.edges: + xi, yi = pose[i][:2] + xj, yj = pose[j][:2] + ci = pose[i][2] if pose.shape[1] > 2 else 1.0 + cj = pose[j][2] if pose.shape[1] > 2 else 1.0 + if np.isnan(xi) or np.isnan(yi) or ci < p_cutoff or np.isnan(xj) or np.isnan(yj) or cj < p_cutoff: + continue + + p1 = (int(xi * sx + ox), int(yi * sy + oy)) + p2 = (int(xj * sx + ox), int(yj * sy + oy)) + + if st.mode == SkeletonColorMode.GRADIENT_KEYPOINTS: + c1 = keypoint_colors[i] + c2 = keypoint_colors[j] + self._draw_gradient_edge(overlay, p1, p2, c1, c2, th, st.gradient_steps) + else: + # SOLID: priority edge_colors > override > style.color > default_color + color = self.edge_colors.get((i, j), color_override or st.color or self.default_color) + cv2.line(overlay, p1, p2, color, th, lineType=cv2.LINE_AA) + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "") + + def draw_many( + self, + overlay: np.ndarray, + poses: np.ndarray, + p_cutoff: float, + offset: tuple[int, int], + scale: tuple[float, float], + *, + style: SkeletonStyle | None = None, + color_override: BGR | None = None, + keypoint_colors: list[BGR] | None = None, + ) -> SkeletonRenderStatus: + poses = np.asarray(poses) + if poses.ndim != 3: + return SkeletonRenderStatus( + SkeletonRenderCode.POSE_SHAPE_INVALID, + f"Multi-pose must be (A,N,2/3); got shape={poses.shape}", + ) + + expected = len(self.keypoints) + if poses.shape[1] != expected: + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Skeleton expects {expected} keypoints, but poses have N={poses.shape[1]}.", + ) + + for pose in poses: + st = self.draw( + overlay, + pose, + p_cutoff, + offset, + scale, + style=style, + color_override=color_override, + keypoint_colors=keypoint_colors, + ) + if not st.rendered: + return st + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "")