From 2f3ebb1371fc289df357dd6b4643484cfb676faf Mon Sep 17 00:00:00 2001 From: Cyril Achard Date: Wed, 25 Feb 2026 10:40:33 +0100 Subject: [PATCH 1/2] Add skeleton rendering and GUI integration Introduce a new skeleton utility and integrate skeleton overlays into the GUI. Adds dlclivegui/utils/skeleton.py (SkeletonModel, Skeleton, loaders, render status/codes) and a helper to load DLC config skeleton. Wire skeleton drawing into display.draw_pose and the main window: add UI checkbox, auto-enable/disable logic, model-based skeleton configuration, and safe handling when keypoint counts or shapes mismatch. Also add a QScrollArea for the controls dock to allow scrolling. Remove unused demo code Add skeleton UI controls and gradient coloring Add UI and rendering support for configurable skeleton appearance: color mode (solid or keypoint-gradient) and line thickness. main_window.py: introduce skeleton color combo, thickness spinbox, handlers (_on_skeleton_style_changed, _sync_skeleton_controls_from_model), and wire these into model loading and drawing flow; enable/disable controls appropriately and preserve auto-disable behavior. gui/misc/color_dropdowns.py: add gradient swatch icon and helpers to create/populate/get/set a skeleton color combo (supports Gradient and solid BGR swatches). utils/display.py: add keypoint_colors_bgr(colormap, num_keypoints) to produce exact BGR colors from a Matplotlib colormap and remove direct skeleton coupling from draw_pose. utils/skeleton.py: ensure draw_many forwards style, color_override and keypoint_colors to per-pose draw calls and validates pose shapes/keypoint counts. Overall this enables gradient coloring of skeleton lines based on keypoint colormap and exposes user controls to tweak skeleton rendering. --- dlclivegui/assets/skeletons/__init__.py | 0 dlclivegui/gui/main_window.py | 279 ++++++++++++++++- dlclivegui/gui/misc/color_dropdowns.py | 133 +++++++- dlclivegui/temp/yolo/__init__.py | 0 dlclivegui/utils/display.py | 38 ++- dlclivegui/utils/skeleton.py | 384 ++++++++++++++++++++++++ 6 files changed, 828 insertions(+), 6 deletions(-) create mode 100644 dlclivegui/assets/skeletons/__init__.py create mode 100644 dlclivegui/temp/yolo/__init__.py create mode 100644 dlclivegui/utils/skeleton.py diff --git a/dlclivegui/assets/skeletons/__init__.py b/dlclivegui/assets/skeletons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlclivegui/gui/main_window.py b/dlclivegui/gui/main_window.py index 0e3beb0..798870d 100644 --- a/dlclivegui/gui/main_window.py +++ b/dlclivegui/gui/main_window.py @@ -42,6 +42,7 @@ QMainWindow, QMessageBox, QPushButton, + QScrollArea, QSizePolicy, QSpinBox, QStatusBar, @@ -70,7 +71,8 @@ ) from ..services.dlc_processor import DLCLiveProcessor, PoseResult from ..services.multi_camera_controller import MultiCameraController, MultiFrameData, get_camera_id -from ..utils.display import BBoxColors, compute_tile_info, create_tiled_frame, draw_bbox, draw_pose +from ..utils import skeleton as skel +from ..utils.display import BBoxColors, compute_tile_info, create_tiled_frame, draw_bbox, draw_pose, keypoint_colors_bgr from ..utils.settings_store import DLCLiveGUISettingsStore, ModelPathStore from ..utils.stats import format_dlc_stats from ..utils.utils import FPSTracker @@ -164,6 +166,10 @@ def __init__(self, config: ApplicationSettings | None = None): self._p_cutoff = 0.6 self._colormap = "hot" self._bbox_color = (0, 0, 255) # BGR: red + ## Skeleton settings + self._skeleton: skel.Skeleton | None = None + self._skeleton_auto_disabled: bool = False + self._last_skeleton_disable_msg: str | None = None # Multi-camera state self._multi_camera_mode = False @@ -207,6 +213,8 @@ def __init__(self, config: ApplicationSettings | None = None): def resizeEvent(self, event): super().resizeEvent(event) + if hasattr(self, "controls_scroll"): + self._sync_controls_dock_min_width() if not self.multi_camera_controller.is_running(): self._show_logo_and_text() @@ -224,6 +232,34 @@ def _apply_theme(self, mode: AppStyle) -> None: def _load_icons(self): self.setWindowIcon(QIcon(LOGO)) + def _sync_controls_dock_min_width(self) -> None: + """Ensure the dock/scroll area is at least as wide as the controls content.""" + if not hasattr(self, "controls_scroll") or self.controls_scroll is None: + return + w = self.controls_scroll.widget() + if w is None: + return + + # Ensure layout has calculated its hints + w.adjustSize() + + # Minimum width needed by the controls content + content_w = w.minimumSizeHint().width() + if content_w <= 0: + content_w = w.sizeHint().width() + + # Reserve space for the vertical scrollbar (even if not currently visible) + vbar_w = self.controls_scroll.verticalScrollBar().sizeHint().width() + + # Account for scrollarea frame/margins + frame = self.controls_scroll.frameWidth() * 2 + + target = content_w + vbar_w + frame + + # Apply to both scroll area and dock (dock is what user resizes) + self.controls_scroll.setMinimumWidth(target) + self.controls_dock.setMinimumWidth(target) + def _setup_ui(self) -> None: # central = QWidget() # layout = QHBoxLayout(central) @@ -262,6 +298,7 @@ def _setup_ui(self) -> None: controls_widget.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) controls_layout = QVBoxLayout(controls_widget) controls_layout.setContentsMargins(5, 5, 5, 5) + controls_layout.setSizeConstraint(QVBoxLayout.SizeConstraint.SetMinimumSize) controls_layout.addWidget(self._build_camera_group()) controls_layout.addWidget(self._build_dlc_group()) controls_layout.addWidget(self._build_recording_group()) @@ -287,7 +324,16 @@ def _setup_ui(self) -> None: ## Dock widget for controls self.controls_dock = QDockWidget("Controls", self) self.controls_dock.setObjectName("ControlsDock") # important for state saving - self.controls_dock.setWidget(controls_widget) + self.controls_scroll = QScrollArea() + controls_widget.setSizePolicy(QSizePolicy.MinimumExpanding, QSizePolicy.Preferred) + self.controls_scroll.setWidget(controls_widget) + self.controls_scroll.setWidgetResizable(True) + self.controls_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.controls_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAsNeeded) + self.controls_scroll.setFrameShape(QScrollArea.Shape.NoFrame) + self.controls_scroll.setSizeAdjustPolicy(QScrollArea.SizeAdjustPolicy.AdjustToContents) + # self.controls_dock.setWidget(controls_widget) + self.controls_dock.setWidget(self.controls_scroll) ### Dock features self.controls_dock.setFeatures( # must not be closable by user but visibility can be toggled from View -> Show controls @@ -310,6 +356,7 @@ def _setup_ui(self) -> None: self.setStatusBar(QStatusBar()) self._build_menus() + QTimer.singleShot(0, self._sync_controls_dock_min_width) # ensure dock is wide enough for controls after layout QTimer.singleShot(0, self._show_logo_and_text) def _build_stats_layout(self, stats_widget: QWidget) -> QGridLayout: @@ -664,8 +711,10 @@ def _build_recording_group(self) -> QGroupBox: return group def _build_viz_group(self) -> QGroupBox: + # Visualization settings group group = QGroupBox("Visualization") form = QFormLayout(group) + ## Pose overlay self.show_predictions_checkbox = QCheckBox("Display pose predictions") self.show_predictions_checkbox.setChecked(True) @@ -690,6 +739,47 @@ def _build_viz_group(self) -> QGroupBox: ) form.addRow(keypoints_settings) + ## Skeleton overlay + self.show_skeleton_checkbox = QCheckBox("Display skeleton") + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.setEnabled(False) + self.show_skeleton_checkbox.setToolTip( + "If enabled, draws connections between keypoints based on the model's skeleton definition.\n" + "Auto-disables if the model keypoints do not match the skeleton definition." + ) + + # Skeleton color mode / color + self.skeleton_color_combo = color_ui.make_skeleton_color_combo( + BBoxColors, # re-use PrimaryColors palette; you aliased it already + current_mode="solid", + current_color=(0, 255, 255), + include_icons=True, + tooltip="Select skeleton color, or Gradient to blend endpoint keypoint colors", + sizing=color_ui.ComboSizing(min_width=80, max_width=200), + ) + self.skeleton_color_combo.setEnabled(False) + + # Skeleton thickness + self.skeleton_thickness_spin = QSpinBox() + self.skeleton_thickness_spin.setRange(1, 20) + self.skeleton_thickness_spin.setValue(2) + self.skeleton_thickness_spin.setToolTip("Skeleton line thickness (scaled with zoom if enabled in style)") + self.skeleton_thickness_spin.setEnabled(False) + + # Layout like keypoints/bbox + skeleton_row = lyts.make_two_field_row( + "Skeleton:", + self.skeleton_color_combo, + None, + self.show_skeleton_checkbox, + key_width=120, + left_stretch=0, + right_stretch=0, + ) + form.addRow(skeleton_row) + form.addRow("Skeleton thickness:", self.skeleton_thickness_spin) + + ## Bounding box overlay & controls self.bbox_enabled_checkbox = QCheckBox("Show bounding box") self.bbox_enabled_checkbox.setChecked(False) @@ -740,7 +830,7 @@ def _build_viz_group(self) -> QGroupBox: self.bbox_y1_spin.setValue(100) bbox_layout.addWidget(self.bbox_y1_spin) - form.addRow("Coordinates", bbox_layout) + form.addRow("Box coordinates", bbox_layout) return group @@ -760,6 +850,10 @@ def _connect_signals(self) -> None: # Visualization settings ## Colormap change self.cmap_combo.currentIndexChanged.connect(self._on_colormap_changed) + ## Skeleton change + self.show_skeleton_checkbox.stateChanged.connect(self._on_show_skeleton_changed) + self.skeleton_color_combo.currentIndexChanged.connect(self._on_skeleton_style_changed) + self.skeleton_thickness_spin.valueChanged.connect(self._on_skeleton_style_changed) ## Connect bounding box controls self.bbox_enabled_checkbox.stateChanged.connect(self._on_bbox_changed) self.bbox_x0_spin.valueChanged.connect(self._on_bbox_changed) @@ -839,6 +933,7 @@ def _apply_config(self, config: ApplicationSettings) -> None: self.bbox_y1_spin.setValue(bbox.y1) # Set visualization settings from config + ## Keypoints viz = config.visualization self._p_cutoff = viz.p_cutoff self._colormap = viz.colormap @@ -847,6 +942,9 @@ def _apply_config(self, config: ApplicationSettings) -> None: self._bbox_color = viz.get_bbox_color_bgr() if hasattr(self, "bbox_color_combo"): color_ui.set_bbox_combo_from_bgr(self.bbox_color_combo, self._bbox_color) + ## Skeleton + if resolved_model_path.strip(): + self._configure_skeleton_for_model(resolved_model_path) # Update DLC camera list self._refresh_dlc_camera_list() @@ -1018,6 +1116,7 @@ def _action_browse_model(self) -> None: return file_path = str(file_path) self.model_path_edit.setText(file_path) + self._configure_skeleton_for_model(file_path) # Persist model path + directory self._model_path_store.save_if_valid(file_path) @@ -1180,6 +1279,38 @@ def _on_bbox_color_changed(self, _index: int) -> None: if self._current_frame is not None: self._display_frame(self._current_frame, force=True) + def _on_show_skeleton_changed(self, _state: int) -> None: + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + if self._current_frame is not None: + self._display_frame(self._current_frame, force=True) + + def _on_skeleton_style_changed(self, _value: int = 0) -> None: + """Apply UI skeleton styling to the current Skeleton instance.""" + if self._skeleton is None: + return + + mode, color = color_ui.get_skeleton_style_from_combo( + self.skeleton_color_combo, + fallback_mode="solid", + fallback_color=self._skeleton.style.color, + ) + + # Update style mode + if mode == "gradient_keypoints": + self._skeleton.style.mode = skel.SkeletonColorMode.GRADIENT_KEYPOINTS + else: + self._skeleton.style.mode = skel.SkeletonColorMode.SOLID + if color is not None: + self._skeleton.style.color = tuple(color) + + # Thickness + self._skeleton.style.thickness = int(self.skeleton_thickness_spin.value()) + + # Redraw + if self._current_frame is not None: + self._display_frame(self._current_frame, force=True) + # ------------------------------------------------------------------ # Multi-camera def _open_camera_config_dialog(self) -> None: @@ -1338,6 +1469,18 @@ def _render_overlays_for_recording(self, cam_id, frame): offset=offset, scale=scale, ) + + if self._skeleton and hasattr(self, "show_skeleton_checkbox") and self.show_skeleton_checkbox.isChecked(): + pose_arr = np.asarray(self._last_pose.pose) + if pose_arr.ndim == 3: + st = self._skeleton.draw_many(output, pose_arr, self._p_cutoff, offset, scale) + else: + self._skeleton.draw(output, self._last_pose.pose, self._p_cutoff, offset, scale) + if st.should_disable: + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.blockSignals(False) + if self._bbox_enabled: output = draw_bbox( frame=output, @@ -1608,6 +1751,68 @@ def _stop_preview(self) -> None: self.camera_stats_label.setText("Camera idle") # self._show_logo_and_text() + def _sync_skeleton_controls_from_model(self) -> None: + """Enable and initialize skeleton UI controls from the current Skeleton.style.""" + enabled = self._skeleton is not None + + self.show_skeleton_checkbox.setEnabled(enabled) + self.skeleton_color_combo.setEnabled(enabled) + self.skeleton_thickness_spin.setEnabled(enabled) + + if not enabled: + return + + # Set thickness + self.skeleton_thickness_spin.blockSignals(True) + self.skeleton_thickness_spin.setValue(int(self._skeleton.style.thickness)) + self.skeleton_thickness_spin.blockSignals(False) + + # Set color/mode + mode = ( + "gradient_keypoints" if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS else "solid" + ) + color_ui.set_skeleton_combo_from_style( + self.skeleton_color_combo, + mode=mode, + color=self._skeleton.style.color, + ) + + if hasattr(self.skeleton_color_combo, "update_shrink_width"): + self.skeleton_color_combo.update_shrink_width() + + def _configure_skeleton_for_model(self, model_path: str) -> None: + """Select an appropriate skeleton definition for the currently configured model.""" + self._skeleton = None + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + + # Default: disable until we find a compatible skeleton + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.setEnabled(False) + # keep checked state but it won't be used unless enabled + + p = Path(model_path).expanduser() + + root = p if p.is_dir() else p.parent + cfg = root / "config.yaml" + if cfg.exists(): + try: + sk = skel.load_dlc_skeleton(cfg) + except Exception as e: + logger.warning(f"Failed to load DLC skeleton from {cfg}: {e}") + sk = None + + if sk is not None: + self._skeleton = sk + self._sync_skeleton_controls_from_model() + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.setEnabled(True) + self.statusBar().showMessage("Skeleton available: DLC config.yaml", 3000) + return + + # None found + self.statusBar().showMessage("No skeleton definition available for this model.", 3000) + def _configure_dlc(self) -> bool: try: settings = self._dlc_settings_from_ui() @@ -1639,6 +1844,7 @@ def _configure_dlc(self) -> bool: self.statusBar().showMessage(f"Processor selection ignored (control disabled): {selected_key}", 3000) self._dlc.configure(settings, processor=processor) + self._configure_skeleton_for_model(settings.model_path) self._model_path_store.save_if_valid(settings.model_path) return True @@ -1864,6 +2070,18 @@ def _stop_inference(self, show_message: bool = True) -> None: self._last_processor_vid_recording = False self._auto_record_session_name = None + # Reset skeleton + self._skeleton = None + self._skeleton_auto_disabled = False + self._last_skeleton_disable_msg = None + self.skeleton_color_combo.setEnabled(False) + self.skeleton_thickness_spin.setEnabled(False) + if hasattr(self, "show_skeleton_checkbox"): + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.setEnabled(False) + self.show_skeleton_checkbox.blockSignals(False) + # Reset button appearance self.start_inference_button.setText("Start pose inference") self.start_inference_button.setStyleSheet("") @@ -1908,6 +2126,59 @@ def _on_dlc_error(self, message: str) -> None: self._stop_inference(show_message=False) self._show_error(message) + def _try_draw_skeleton(self, overlay: np.ndarray, pose: np.ndarray) -> None: + if self._skeleton is None: + return + if not self.show_skeleton_checkbox.isChecked(): + return + if self._skeleton_auto_disabled: + return + + pose_arr = np.asarray(pose) + + # Compute keypoint colors only if gradient mode is active + kp_colors = None + try: + if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS: + n_kpts = pose_arr.shape[1] if pose_arr.ndim == 3 else pose_arr.shape[0] + kp_colors = keypoint_colors_bgr(self._colormap, int(n_kpts)) + + if pose_arr.ndim == 3: + status = self._skeleton.draw_many( + overlay, + pose_arr, + p_cutoff=self._p_cutoff, + offset=self._dlc_tile_offset, + scale=self._dlc_tile_scale, + keypoint_colors=kp_colors, + ) + else: + status = self._skeleton.draw( + overlay, + pose_arr, + p_cutoff=self._p_cutoff, + offset=self._dlc_tile_offset, + scale=self._dlc_tile_scale, + keypoint_colors=kp_colors, + ) + + except Exception as e: + status = skel.SkeletonRenderStatus( + code=skel.SkeletonRenderCode.POSE_SHAPE_INVALID, + message=f"Skeleton rendering error: {e}", + ) + + if status.should_disable: + self._skeleton_auto_disabled = True + msg = status.message or "Skeleton disabled due to keypoint mismatch." + if msg != self._last_skeleton_disable_msg: + self._last_skeleton_disable_msg = msg + self.statusBar().showMessage(f"Skeleton disabled: {msg}", 6000) + + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(False) + self.show_skeleton_checkbox.blockSignals(False) + def _update_video_display(self, frame: np.ndarray) -> None: display_frame = frame @@ -1921,6 +2192,8 @@ def _update_video_display(self, frame: np.ndarray) -> None: scale=self._dlc_tile_scale, ) + self._try_draw_skeleton(display_frame, self._last_pose.pose) + if self._bbox_enabled: display_frame = draw_bbox( display_frame, diff --git a/dlclivegui/gui/misc/color_dropdowns.py b/dlclivegui/gui/misc/color_dropdowns.py index bb0f0ac..c434ae2 100644 --- a/dlclivegui/gui/misc/color_dropdowns.py +++ b/dlclivegui/gui/misc/color_dropdowns.py @@ -16,7 +16,7 @@ import numpy as np from PySide6.QtCore import Qt -from PySide6.QtGui import QColor, QIcon, QImage, QPainter, QPixmap +from PySide6.QtGui import QBrush, QColor, QIcon, QImage, QLinearGradient, QPainter, QPixmap from PySide6.QtWidgets import ( QComboBox, QSizePolicy, @@ -28,6 +28,26 @@ TEnum = TypeVar("TEnum") +def make_gradient_swatch_icon(*, width: int = 40, height: int = 16, border: int = 1) -> QIcon: + """Small gradient swatch icon for 'Gradient' mode.""" + pix = QPixmap(width, height) + pix.fill(Qt.transparent) + p = QPainter(pix) + + # border/background + p.fillRect(0, 0, width, height, Qt.black) + p.fillRect(border, border, width - 2 * border, height - 2 * border, Qt.white) + + # inner gradient (blue -> red, arbitrary but clearly "gradient") + grad = QLinearGradient(border + 1, 0, width - (border + 1), 0) + grad.setColorAt(0.0, QColor(0, 140, 255)) + grad.setColorAt(1.0, QColor(255, 80, 0)) + p.fillRect(border + 1, border + 1, width - 2 * (border + 1), height - 2 * (border + 1), QBrush(grad)) + + p.end() + return QIcon(pix) + + # ----------------------------------------------------------------------------- # Combo sizing: shrink to current selection + wide popup # ----------------------------------------------------------------------------- @@ -413,3 +433,114 @@ def get_cmap_name_from_combo(combo: QComboBox, *, fallback: str = "viridis") -> return data text = combo.currentText().strip() return text or fallback + + +# ----------------------------------------------------------------------------- +# Skeleton color combo helpers (enum-based + gradient) +# ----------------------------------------------------------------------------- +def get_skeleton_style_from_combo( + combo: QComboBox, + *, + fallback_mode: str = "solid", + fallback_color: BGR | None = None, +) -> tuple[str, BGR | None]: + data = combo.currentData() + if isinstance(data, dict): + mode = data.get("mode", fallback_mode) + color = data.get("color", fallback_color) + return mode, color + return fallback_mode, fallback_color + + +def make_skeleton_color_combo( + colors_enum: Iterable[TEnum], + *, + current_mode: str = "solid", + current_color: BGR | None = (0, 255, 255), + include_icons: bool = True, + tooltip: str = "Select skeleton line color or Gradient (from keypoints)", + sizing: ComboSizing | None = None, +) -> QComboBox: + combo = ShrinkCurrentWidePopupComboBox(sizing=sizing) if sizing is not None else QComboBox() + combo.setToolTip(tooltip) + populate_skeleton_color_combo( + combo, + colors_enum, + current_mode=current_mode, + current_color=current_color, + include_icons=include_icons, + ) + if isinstance(combo, ShrinkCurrentWidePopupComboBox): + combo.update_shrink_width() + return combo + + +def set_skeleton_combo_from_style(combo: QComboBox, *, mode: str, color: BGR | None) -> None: + """Select the best matching item.""" + # Gradient + if mode == "gradient_keypoints": + combo.findData({"mode": "gradient_keypoints"}) # may fail due to dict identity + # robust fallback: scan + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "gradient_keypoints": + combo.setCurrentIndex(i) + return + return + + # Solid with color + if color is not None: + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "solid" and tuple(d.get("color")) == tuple(color): + combo.setCurrentIndex(i) + return + + # Default: first solid entry + for i in range(combo.count()): + d = combo.itemData(i) + if isinstance(d, dict) and d.get("mode") == "solid": + combo.setCurrentIndex(i) + return + + +def populate_skeleton_color_combo( + combo: QComboBox, + colors_enum: Iterable[TEnum], + *, + current_mode: str = "solid", + current_color: BGR | None = None, + include_icons: bool = True, + gradient_label: str = "Gradient (from keypoints)", +) -> None: + """ + Populate combo with: + - Gradient mode + - Solid colors (from enum values BGR) + ItemData is a dict with keys: + - mode: 'solid' or 'gradient_keypoints' + - color: optional BGR for solid + """ + combo.blockSignals(True) + combo.clear() + + # 1) Gradient option + if include_icons: + combo.addItem(make_gradient_swatch_icon(), gradient_label, {"mode": "gradient_keypoints"}) + else: + combo.addItem(gradient_label, {"mode": "gradient_keypoints"}) + + # 2) Solid colors + for enum_item in colors_enum: + bgr: BGR = enum_item.value + name = getattr(enum_item, "name", str(enum_item)).title() + data = {"mode": "solid", "color": bgr} + if include_icons: + combo.addItem(make_bgr_swatch_icon(bgr), name, data) + else: + combo.addItem(name, data) + + # Select current + set_skeleton_combo_from_style(combo, mode=current_mode, color=current_color) + + combo.blockSignals(False) diff --git a/dlclivegui/temp/yolo/__init__.py b/dlclivegui/temp/yolo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dlclivegui/utils/display.py b/dlclivegui/utils/display.py index 0eac657..05c3408 100644 --- a/dlclivegui/utils/display.py +++ b/dlclivegui/utils/display.py @@ -8,7 +8,7 @@ import numpy as np -class BBoxColors(enum.Enum): +class PrimaryColors(enum.Enum): RED = (0, 0, 255) GREEN = (0, 255, 0) BLUE = (255, 0, 0) @@ -20,7 +20,26 @@ class BBoxColors(enum.Enum): @staticmethod def get_all_display_names() -> list[str]: - return [color.name.capitalize() for color in BBoxColors] + return [c.name.capitalize() for c in PrimaryColors] + + +BBoxColors = PrimaryColors + + +class SkeletonColors(enum.Enum): + GRADIENT = "gradient" # special mode + RED = PrimaryColors.RED.value + GREEN = PrimaryColors.GREEN.value + BLUE = PrimaryColors.BLUE.value + YELLOW = PrimaryColors.YELLOW.value + CYAN = PrimaryColors.CYAN.value + MAGENTA = PrimaryColors.MAGENTA.value + WHITE = PrimaryColors.WHITE.value + BLACK = PrimaryColors.BLACK.value + + @staticmethod + def get_all_display_names() -> list[str]: + return ["Gradient"] + [c.name.capitalize() for c in SkeletonColors if c != SkeletonColors.GRADIENT] def color_to_rgb(color_name: str) -> tuple[int, int, int]: @@ -31,6 +50,21 @@ def color_to_rgb(color_name: str) -> tuple[int, int, int]: raise ValueError(f"Unknown color name: {color_name}") from None +def keypoint_colors_bgr(colormap: str, num_keypoints: int) -> list[tuple[int, int, int]]: + """ + Return the exact BGR colors used by draw_keypoints() for a given Matplotlib colormap + and number of keypoints. + """ + cmap = plt.get_cmap(colormap) + colors: list[tuple[int, int, int]] = [] + for idx in range(num_keypoints): + t = idx / max(num_keypoints - 1, 1) + rgba = cmap(t) + bgr = (int(rgba[2] * 255), int(rgba[1] * 255), int(rgba[0] * 255)) + colors.append(bgr) + return colors + + def compute_tiling_geometry( frames: dict[str, np.ndarray], max_canvas: tuple[int, int] = (1200, 800), diff --git a/dlclivegui/utils/skeleton.py b/dlclivegui/utils/skeleton.py new file mode 100644 index 0000000..3904572 --- /dev/null +++ b/dlclivegui/utils/skeleton.py @@ -0,0 +1,384 @@ +# dlclivegui/utils/skeleton.py +from __future__ import annotations + +import json +from dataclasses import dataclass +from enum import Enum, auto +from pathlib import Path + +import cv2 +import numpy as np +import yaml +from pydantic import BaseModel, Field, ValidationError, field_validator + + +# ############### # +# Status & code # +# ############### # +class SkeletonRenderCode(Enum): + OK = auto() + POSE_SHAPE_INVALID = auto() + KEYPOINT_COUNT_MISMATCH = auto() + + +@dataclass(frozen=True) +class SkeletonRenderStatus: + code: SkeletonRenderCode + message: str = "" + + @property + def rendered(self) -> bool: + return self.code == SkeletonRenderCode.OK + + @property + def should_disable(self) -> bool: + # GUI can switch off skeleton drawing if True + return self.code in { + SkeletonRenderCode.POSE_SHAPE_INVALID, + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + } + + +# ############ # +# Exceptions # +# ############ # + + +class SkeletonError(ValueError): + """Raised when a skeleton definition is invalid.""" + + +class SkeletonLoadError(Exception): + """High-level skeleton loading error (safe for GUI display).""" + + +class SkeletonValidationError(SkeletonLoadError): + """Schema or semantic validation error.""" + + +# ################## # +# Skeleton display # +# ################## # + +BGR = tuple[int, int, int] # (B, G, R) color format + + +class SkeletonColorMode(str, Enum): + SOLID = "solid" + GRADIENT_KEYPOINTS = "gradient_keypoints" # use endpoint keypoint colors + + +@dataclass +class SkeletonStyle: + mode: SkeletonColorMode = SkeletonColorMode.SOLID + color: BGR = (0, 255, 255) # default if SOLID + thickness: int = 2 # base thickness in pixels + gradient_steps: int = 16 # segments per edge when gradient + scale_with_zoom: bool = True # scale thickness with (sx, sy) + + def effective_thickness(self, sx: float, sy: float) -> int: + if not self.scale_with_zoom: + return max(1, int(self.thickness)) + return max(1, int(round(self.thickness * min(sx, sy)))) + + +class SkeletonStyleModel(BaseModel): + mode: SkeletonColorMode = SkeletonColorMode.SOLID + color: BGR = (0, 255, 255) # default if SOLID + thickness: int = Field(2, ge=1, description="Base thickness in pixels") + gradient_steps: int = Field(16, ge=2, description="Segments per edge when gradient") + scale_with_zoom: bool = True + + @field_validator("thickness") + @classmethod + def _thickness_positive(cls, v): + if v < 1: + raise ValueError("Thickness must be at least 1 pixel") + return v + + @field_validator("gradient_steps") + @classmethod + def _steps_positive(cls, v): + if v < 2: + raise ValueError("gradient_steps must be >= 2") + return v + + +# ############# # +# Skeleton IO # +# ############# # +class SkeletonModel(BaseModel): + """Validated skeleton definition (IO + schema).""" + + name: str | None = None + + keypoints: list[str] = Field(..., min_length=1, description="Ordered list of keypoint names") + + edges: list[tuple[int, int]] = Field( + default_factory=list, + description="List of (i, j) keypoint index pairs", + ) + + style: SkeletonStyleModel = Field(default_factory=SkeletonStyleModel) + default_color: BGR = (0, 255, 255) # used if style.color is None or in SOLID mode + edge_colors: dict[tuple[int, int], BGR] = Field(default_factory=dict) + + schema_version: int = 1 + + @field_validator("keypoints") + @classmethod + def validate_unique_keypoints(cls, v): + if len(set(v)) != len(v): + raise ValueError("Duplicate keypoint names detected") + return v + + @field_validator("edges") + @classmethod + def validate_edges(cls, edges, info): + keypoints = info.data.get("keypoints", []) + n = len(keypoints) + + for i, j in edges: + if i == j: + raise ValueError(f"Self-loop detected in edge ({i}, {j})") + if not (0 <= i < n and 0 <= j < n): + raise ValueError(f"Edge ({i}, {j}) out of range for {n} keypoints") + return edges + + +def _load_raw_skeleton_data(path: Path) -> dict: + if not path.exists(): + raise SkeletonLoadError(f"Skeleton file not found: {path}") + + if path.suffix in {".yaml", ".yml"}: + return yaml.safe_load(path.read_text()) + + if path.suffix == ".json": + return json.loads(path.read_text()) + + raise SkeletonLoadError(f"Unsupported file type: {path.suffix}") + + +def _format_pydantic_error(err: ValidationError) -> str: + lines = ["Invalid skeleton definition:"] + for e in err.errors(): + loc = " → ".join(map(str, e["loc"])) + msg = e["msg"] + lines.append(f"• {loc}: {msg}") + return "\n".join(lines) + + +def load_skeleton(path: Path) -> Skeleton: + try: + data = _load_raw_skeleton_data(path) + model = SkeletonModel.model_validate(data) + return Skeleton(model) + + except ValidationError as e: + raise SkeletonValidationError(_format_pydantic_error(e)) from None + + except Exception as e: + raise SkeletonLoadError(str(e)) from None + + +def save_skeleton(path: Path, model: SkeletonModel) -> None: + data = model.model_dump() + + if path.suffix in {".yaml", ".yml"}: + path.write_text(yaml.safe_dump(data, sort_keys=False)) + elif path.suffix == ".json": + path.write_text(json.dumps(data, indent=2)) + else: + raise SkeletonLoadError(f"Unsupported skeleton file type: {path.suffix}") + + +def load_dlc_skeleton(config_path: Path) -> Skeleton | None: + if not config_path.exists(): + raise SkeletonLoadError(f"DLC config not found: {config_path}") + + cfg = yaml.safe_load(config_path.read_text()) + + bodyparts = cfg.get("bodyparts") + if not bodyparts: + return None # No pose info + + edges = [] + + # Newer DLC format + if "skeleton" in cfg: + for a, b in cfg["skeleton"]: + edges.append((bodyparts.index(a), bodyparts.index(b))) + + # Older / alternative formats + elif "skeleton_edges" in cfg: + edges = [tuple(e) for e in cfg["skeleton_edges"]] + + if not edges: + return None + + model = SkeletonModel( + name=cfg.get("Task", "DeepLabCut"), + keypoints=bodyparts, + edges=edges, + ) + + return Skeleton(model) + + +class Skeleton: + """Runtime skeleton optimized for drawing.""" + + def __init__(self, model: SkeletonModel): + self.name = model.name + self.keypoints = model.keypoints + self.edges = model.edges + + self.style = SkeletonStyle( + mode=model.style.mode, + color=model.style.color, + thickness=model.style.thickness, + gradient_steps=model.style.gradient_steps, + scale_with_zoom=model.style.scale_with_zoom, + ) + self.default_color = model.default_color + self.edge_colors = model.edge_colors + + def check_pose_compat(self, pose: np.ndarray) -> SkeletonRenderStatus: + pose = np.asarray(pose) + + if pose.ndim != 2 or pose.shape[1] not in (2, 3): + return SkeletonRenderStatus( + SkeletonRenderCode.POSE_SHAPE_INVALID, + f"Pose must be (N,2) or (N,3); got shape={pose.shape}", + ) + + expected = len(self.keypoints) + got = pose.shape[0] + if got != expected: + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Skeleton expects {expected} keypoints, but pose has {got}.", + ) + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "") + + def _draw_gradient_edge( + self, + img: np.ndarray, + p1: tuple[int, int], + p2: tuple[int, int], + c1: BGR, + c2: BGR, + thickness: int, + steps: int, + ): + x1, y1 = p1 + x2, y2 = p2 + + for s in range(steps): + a0 = s / steps + a1 = (s + 1) / steps + xs0 = int(x1 + (x2 - x1) * a0) + ys0 = int(y1 + (y2 - y1) * a0) + xs1 = int(x1 + (x2 - x1) * a1) + ys1 = int(y1 + (y2 - y1) * a1) + + t = (s + 0.5) / steps + b = int(c1[0] + (c2[0] - c1[0]) * t) + g = int(c1[1] + (c2[1] - c1[1]) * t) + r = int(c1[2] + (c2[2] - c1[2]) * t) + + cv2.line(img, (xs0, ys0), (xs1, ys1), (b, g, r), thickness, lineType=cv2.LINE_AA) + + def draw( + self, + overlay: np.ndarray, + pose: np.ndarray, + p_cutoff: float, + offset: tuple[int, int], + scale: tuple[float, float], + *, + style: SkeletonStyle | None = None, + color_override: BGR | None = None, + keypoint_colors: list[BGR] | None = None, + ) -> SkeletonRenderStatus: + status = self.check_pose_compat(pose) + if not status.rendered: + return status + + st = style or self.style + ox, oy = offset + sx, sy = scale + th = st.effective_thickness(sx, sy) + + # if gradient mode, require keypoint_colors aligned with keypoint order + if st.mode == SkeletonColorMode.GRADIENT_KEYPOINTS: + if keypoint_colors is None or len(keypoint_colors) != len(self.keypoints): + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Gradient mode requires keypoint_colors of length {len(self.keypoints)}.", + ) + + for i, j in self.edges: + xi, yi = pose[i][:2] + xj, yj = pose[j][:2] + ci = pose[i][2] if pose.shape[1] > 2 else 1.0 + cj = pose[j][2] if pose.shape[1] > 2 else 1.0 + if np.isnan(xi) or np.isnan(yi) or ci < p_cutoff or np.isnan(xj) or np.isnan(yj) or cj < p_cutoff: + continue + + p1 = (int(xi * sx + ox), int(yi * sy + oy)) + p2 = (int(xj * sx + ox), int(yj * sy + oy)) + + if st.mode == SkeletonColorMode.GRADIENT_KEYPOINTS: + c1 = keypoint_colors[i] + c2 = keypoint_colors[j] + self._draw_gradient_edge(overlay, p1, p2, c1, c2, th, st.gradient_steps) + else: + # SOLID: priority edge_colors > override > style.color > default_color + color = self.edge_colors.get((i, j), color_override or st.color or self.default_color) + cv2.line(overlay, p1, p2, color, th, lineType=cv2.LINE_AA) + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "") + + def draw_many( + self, + overlay: np.ndarray, + poses: np.ndarray, + p_cutoff: float, + offset: tuple[int, int], + scale: tuple[float, float], + *, + style: SkeletonStyle | None = None, + color_override: BGR | None = None, + keypoint_colors: list[BGR] | None = None, + ) -> SkeletonRenderStatus: + poses = np.asarray(poses) + if poses.ndim != 3: + return SkeletonRenderStatus( + SkeletonRenderCode.POSE_SHAPE_INVALID, + f"Multi-pose must be (A,N,2/3); got shape={poses.shape}", + ) + + expected = len(self.keypoints) + if poses.shape[1] != expected: + return SkeletonRenderStatus( + SkeletonRenderCode.KEYPOINT_COUNT_MISMATCH, + f"Skeleton expects {expected} keypoints, but poses have N={poses.shape[1]}.", + ) + + for pose in poses: + st = self.draw( + overlay, + pose, + p_cutoff, + offset, + scale, + style=style, + color_override=color_override, + keypoint_colors=keypoint_colors, + ) + if not st.rendered: + return st + + return SkeletonRenderStatus(SkeletonRenderCode.OK, "") From 4bcacb7693dc72da6ab747c90286fe17893c2d88 Mon Sep 17 00:00:00 2001 From: Cyril Achard Date: Wed, 25 Feb 2026 11:35:29 +0100 Subject: [PATCH 2/2] Add configurable skeleton styling and UI integration Introduce a first-class SkeletonStyle and SkeletonColorMode (and BGR alias) in config, and wire them through the GUI and skeleton utilities. Updates include: - dlclivegui/config.py: add BGR type, SkeletonColorMode enum, Pydantic SkeletonStyle model, expose skeleton fields on VisualizationSettings and with_overlays on RecordingSettings, and helper color accessors. - dlclivegui/gui/main_window.py: add _apply_viz_settings_to_ui/_apply_viz_settings_to_skeleton, create a unified _draw_skeleton_on_frame renderer, read/write skeleton style from UI, refactor and simplify skeleton enable/disable and model heuristics, and wire recording.with_overlays. - dlclivegui/gui/misc/color_dropdowns.py: reuse BGR from config. - dlclivegui/utils/skeleton.py: remove duplicate style/type definitions, import style and enums from config, and add module docstring. These changes centralize skeleton styling, enable UI control and persistence of style, and clean up duplicate definitions across modules. --- dlclivegui/config.py | 30 +++++ dlclivegui/gui/main_window.py | 161 ++++++++++++++++++------- dlclivegui/gui/misc/color_dropdowns.py | 3 +- dlclivegui/utils/skeleton.py | 25 +--- 4 files changed, 151 insertions(+), 68 deletions(-) diff --git a/dlclivegui/config.py b/dlclivegui/config.py index 6d9e1de..8d56b12 100644 --- a/dlclivegui/config.py +++ b/dlclivegui/config.py @@ -12,6 +12,25 @@ TileLayout = Literal["auto", "2x2", "1x4", "4x1"] Precision = Literal["FP32", "FP16"] ModelType = Literal["pytorch", "tensorflow"] +BGR = tuple[int, int, int] # (B, G, R) color format + + +class SkeletonColorMode(str, Enum): + SOLID = "solid" + GRADIENT_KEYPOINTS = "gradient_keypoints" # use endpoint keypoint colors + + +class SkeletonStyle(BaseModel): + mode: SkeletonColorMode = SkeletonColorMode.SOLID + color: BGR = (0, 255, 255) # default if SOLID + thickness: int = 2 # base thickness in pixels + gradient_steps: int = 16 # segments per edge when gradient + scale_with_zoom: bool = True # scale thickness with (sx, sy) + + def effective_thickness(self, sx: float, sy: float) -> int: + if not self.scale_with_zoom: + return max(1, int(self.thickness)) + return max(1, int(round(self.thickness * min(sx, sy)))) class CameraSettings(BaseModel): @@ -301,12 +320,22 @@ class VisualizationSettings(BaseModel): colormap: str = "hot" bbox_color: tuple[int, int, int] = (0, 0, 255) + show_pose: bool = True + show_skeleton: bool = False + skeleton_style: SkeletonStyle = Field(default_factory=SkeletonStyle) + def get_bbox_color_bgr(self) -> tuple[int, int, int]: """Get bounding box color in BGR format""" if isinstance(self.bbox_color, (list, tuple)) and len(self.bbox_color) == 3: return tuple(int(c) for c in self.bbox_color) return (0, 0, 255) # default red + def get_skeleton_color_bgr(self) -> tuple[int, int, int]: + c = self.skeleton_style.color + if isinstance(c, (list, tuple)) and len(c) == 3: + return tuple(int(v) for v in c) + return (0, 255, 255) # default yellow + class RecordingSettings(BaseModel): enabled: bool = False @@ -315,6 +344,7 @@ class RecordingSettings(BaseModel): container: Literal["mp4", "avi", "mov"] = "mp4" codec: str = "libx264" crf: int = Field(default=23, ge=0, le=51) + with_overlays: bool = False def output_path(self) -> Path: """Return the absolute output path for recordings.""" diff --git a/dlclivegui/gui/main_window.py b/dlclivegui/gui/main_window.py index 798870d..87b810c 100644 --- a/dlclivegui/gui/main_window.py +++ b/dlclivegui/gui/main_window.py @@ -63,6 +63,7 @@ VisualizationSettings, ) +from ..config import SkeletonColorMode, SkeletonStyle from ..processors.processor_utils import ( default_processors_dir, instantiate_from_scan, @@ -891,6 +892,45 @@ def _connect_signals(self) -> None: # ------------------------------------------------------------------ # Config + # ------------------------------------------------------------------ + def _apply_viz_settings_to_ui(self, viz: VisualizationSettings) -> None: + """Set UI state from VisualizationSettings (does not require skeleton to exist).""" + # Pose toggle + self.show_predictions_checkbox.blockSignals(True) + self.show_predictions_checkbox.setChecked(bool(viz.show_pose)) + self.show_predictions_checkbox.blockSignals(False) + + # Skeleton toggle (may remain disabled until skeleton exists) + self.show_skeleton_checkbox.blockSignals(True) + self.show_skeleton_checkbox.setChecked(bool(viz.show_skeleton)) + self.show_skeleton_checkbox.blockSignals(False) + + # Skeleton style controls (combo/spin) - set values even if disabled + if hasattr(self, "skeleton_color_combo"): + mode = viz.skeleton_style.mode.value # "solid" or "gradient_keypoints" + color = tuple(viz.skeleton_style.color) + color_ui.set_skeleton_combo_from_style(self.skeleton_color_combo, mode=mode, color=color) + + if hasattr(self, "skeleton_thickness_spin"): + self.skeleton_thickness_spin.blockSignals(True) + self.skeleton_thickness_spin.setValue(int(viz.skeleton_style.thickness)) + self.skeleton_thickness_spin.blockSignals(False) + + def _apply_viz_settings_to_skeleton(self, viz: VisualizationSettings) -> None: + """Apply VisualizationSettings onto the active runtime Skeleton, if present.""" + if self._skeleton is None: + return + + # Copy style fields + self._skeleton.style.mode = skel.SkeletonColorMode(viz.skeleton_style.mode.value) + self._skeleton.style.color = tuple(viz.skeleton_style.color) + self._skeleton.style.thickness = int(viz.skeleton_style.thickness) + self._skeleton.style.gradient_steps = int(viz.skeleton_style.gradient_steps) + self._skeleton.style.scale_with_zoom = bool(viz.skeleton_style.scale_with_zoom) + + # Enable/disable UI controls now that skeleton exists + self._sync_skeleton_controls_from_model() + def _apply_config(self, config: ApplicationSettings) -> None: # Update active cameras label self._update_active_cameras_label() @@ -907,6 +947,7 @@ def _apply_config(self, config: ApplicationSettings) -> None: self.output_directory_edit.setText(recording.directory) self.filename_edit.setText(recording.filename) self.container_combo.setCurrentText(recording.container) + self.record_with_overlays_checkbox.setChecked(recording.with_overlays) codec_index = self.codec_combo.findText(recording.codec) if codec_index >= 0: self.codec_combo.setCurrentIndex(codec_index) @@ -937,6 +978,7 @@ def _apply_config(self, config: ApplicationSettings) -> None: viz = config.visualization self._p_cutoff = viz.p_cutoff self._colormap = viz.colormap + self._apply_viz_settings_to_ui(viz) if hasattr(self, "cmap_combo"): color_ui.set_cmap_combo_from_name(self.cmap_combo, self._colormap, fallback="viridis") self._bbox_color = viz.get_bbox_color_bgr() @@ -945,6 +987,7 @@ def _apply_config(self, config: ApplicationSettings) -> None: ## Skeleton if resolved_model_path.strip(): self._configure_skeleton_for_model(resolved_model_path) + self._apply_viz_settings_to_skeleton(viz) # Update DLC camera list self._refresh_dlc_camera_list() @@ -1007,6 +1050,7 @@ def _recording_settings_from_ui(self) -> RecordingSettings: container=self.container_combo.currentText().strip() or "mp4", codec=self.codec_combo.currentText().strip() or "libx264", crf=int(self.crf_spin.value()), + with_overlays=self.record_with_overlays_checkbox.isChecked(), ) def _bbox_settings_from_ui(self) -> BoundingBoxSettings: @@ -1019,10 +1063,29 @@ def _bbox_settings_from_ui(self) -> BoundingBoxSettings: ) def _visualization_settings_from_ui(self) -> VisualizationSettings: + # Read skeleton mode+color from combo + mode_str, color = color_ui.get_skeleton_style_from_combo( + self.skeleton_color_combo, + fallback_mode="solid", + fallback_color=(0, 255, 255), + ) + + # Build SkeletonStyle (pydantic) + style = SkeletonStyle( + mode=SkeletonColorMode(mode_str), # or SkeletonColorMode.GRADIENT_KEYPOINTS if mode_str matches + color=tuple(color) if color else (0, 255, 255), + thickness=int(self.skeleton_thickness_spin.value()), + gradient_steps=getattr(self._skeleton.style, "gradient_steps", 16) if self._skeleton else 16, + scale_with_zoom=getattr(self._skeleton.style, "scale_with_zoom", True) if self._skeleton else True, + ) + return VisualizationSettings( p_cutoff=self._p_cutoff, colormap=self._colormap, bbox_color=self._bbox_color, + show_pose=self.show_predictions_checkbox.isChecked(), + show_skeleton=self.show_skeleton_checkbox.isChecked(), + skeleton_style=style, ) # ------------------------------------------------------------------ @@ -1286,28 +1349,16 @@ def _on_show_skeleton_changed(self, _state: int) -> None: self._display_frame(self._current_frame, force=True) def _on_skeleton_style_changed(self, _value: int = 0) -> None: - """Apply UI skeleton styling to the current Skeleton instance.""" if self._skeleton is None: return - mode, color = color_ui.get_skeleton_style_from_combo( - self.skeleton_color_combo, - fallback_mode="solid", - fallback_color=self._skeleton.style.color, - ) - - # Update style mode - if mode == "gradient_keypoints": - self._skeleton.style.mode = skel.SkeletonColorMode.GRADIENT_KEYPOINTS - else: - self._skeleton.style.mode = skel.SkeletonColorMode.SOLID - if color is not None: - self._skeleton.style.color = tuple(color) + mode_str, color = color_ui.get_skeleton_style_from_combo(self.skeleton_color_combo) + self._skeleton.style.mode = skel.SkeletonColorMode(mode_str) + if self._skeleton.style.mode == skel.SkeletonColorMode.SOLID and color is not None: + self._skeleton.style.color = tuple(color) - # Thickness self._skeleton.style.thickness = int(self.skeleton_thickness_spin.value()) - # Redraw if self._current_frame is not None: self._display_frame(self._current_frame, force=True) @@ -1469,17 +1520,12 @@ def _render_overlays_for_recording(self, cam_id, frame): offset=offset, scale=scale, ) - - if self._skeleton and hasattr(self, "show_skeleton_checkbox") and self.show_skeleton_checkbox.isChecked(): - pose_arr = np.asarray(self._last_pose.pose) - if pose_arr.ndim == 3: - st = self._skeleton.draw_many(output, pose_arr, self._p_cutoff, offset, scale) - else: - self._skeleton.draw(output, self._last_pose.pose, self._p_cutoff, offset, scale) - if st.should_disable: - self.show_skeleton_checkbox.blockSignals(True) - self.show_skeleton_checkbox.setChecked(False) - self.show_skeleton_checkbox.blockSignals(False) + self._draw_skeleton_on_frame( + output, + self._last_pose.pose, + offset=offset, + scale=scale, + ) if self._bbox_enabled: output = draw_bbox( @@ -1795,7 +1841,7 @@ def _configure_skeleton_for_model(self, model_path: str) -> None: root = p if p.is_dir() else p.parent cfg = root / "config.yaml" - if cfg.exists(): + if cfg.exists() and self._skeleton is None: try: sk = skel.load_dlc_skeleton(cfg) except Exception as e: @@ -1808,7 +1854,15 @@ def _configure_skeleton_for_model(self, model_path: str) -> None: if hasattr(self, "show_skeleton_checkbox"): self.show_skeleton_checkbox.setEnabled(True) self.statusBar().showMessage("Skeleton available: DLC config.yaml", 3000) - return + + if self._skeleton is not None: + try: + viz = self._config.visualization + self._apply_viz_settings_to_skeleton(viz) + except Exception as e: + logger.warning(f"Failed to apply visualization settings to skeleton: {e}") + pass + return # None found self.statusBar().showMessage("No skeleton definition available for this model.", 3000) @@ -2126,30 +2180,39 @@ def _on_dlc_error(self, message: str) -> None: self._stop_inference(show_message=False) self._show_error(message) - def _try_draw_skeleton(self, overlay: np.ndarray, pose: np.ndarray) -> None: + def _draw_skeleton_on_frame( + self, + overlay: np.ndarray, + pose: np.ndarray, + *, + offset: tuple[int, int], + scale: tuple[float, float], + allow_auto_disable: bool = True, + ) -> skel.SkeletonRenderStatus | None: + """Draw skeleton on overlay with correct style. Optionally auto-disables UI on mismatch.""" if self._skeleton is None: - return + return None if not self.show_skeleton_checkbox.isChecked(): - return + return None if self._skeleton_auto_disabled: - return + return None pose_arr = np.asarray(pose) - # Compute keypoint colors only if gradient mode is active + # Provide keypoint_colors iff gradient mode is active kp_colors = None - try: - if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS: - n_kpts = pose_arr.shape[1] if pose_arr.ndim == 3 else pose_arr.shape[0] - kp_colors = keypoint_colors_bgr(self._colormap, int(n_kpts)) + if self._skeleton.style.mode == skel.SkeletonColorMode.GRADIENT_KEYPOINTS: + n_kpts = pose_arr.shape[1] if pose_arr.ndim == 3 else pose_arr.shape[0] + kp_colors = keypoint_colors_bgr(self._colormap, int(n_kpts)) + try: if pose_arr.ndim == 3: status = self._skeleton.draw_many( overlay, pose_arr, p_cutoff=self._p_cutoff, - offset=self._dlc_tile_offset, - scale=self._dlc_tile_scale, + offset=offset, + scale=scale, keypoint_colors=kp_colors, ) else: @@ -2157,20 +2220,19 @@ def _try_draw_skeleton(self, overlay: np.ndarray, pose: np.ndarray) -> None: overlay, pose_arr, p_cutoff=self._p_cutoff, - offset=self._dlc_tile_offset, - scale=self._dlc_tile_scale, + offset=offset, + scale=scale, keypoint_colors=kp_colors, ) - except Exception as e: status = skel.SkeletonRenderStatus( code=skel.SkeletonRenderCode.POSE_SHAPE_INVALID, message=f"Skeleton rendering error: {e}", ) - if status.should_disable: + if allow_auto_disable and status.should_disable: self._skeleton_auto_disabled = True - msg = status.message or "Skeleton disabled due to keypoint mismatch." + msg = status.message or "Skeleton disabled due to mismatch." if msg != self._last_skeleton_disable_msg: self._last_skeleton_disable_msg = msg self.statusBar().showMessage(f"Skeleton disabled: {msg}", 6000) @@ -2179,6 +2241,8 @@ def _try_draw_skeleton(self, overlay: np.ndarray, pose: np.ndarray) -> None: self.show_skeleton_checkbox.setChecked(False) self.show_skeleton_checkbox.blockSignals(False) + return status + def _update_video_display(self, frame: np.ndarray) -> None: display_frame = frame @@ -2192,7 +2256,12 @@ def _update_video_display(self, frame: np.ndarray) -> None: scale=self._dlc_tile_scale, ) - self._try_draw_skeleton(display_frame, self._last_pose.pose) + self._draw_skeleton_on_frame( + display_frame, + self._last_pose.pose, + offset=self._dlc_tile_offset, + scale=self._dlc_tile_scale, + ) if self._bbox_enabled: display_frame = draw_bbox( diff --git a/dlclivegui/gui/misc/color_dropdowns.py b/dlclivegui/gui/misc/color_dropdowns.py index c434ae2..6f1351a 100644 --- a/dlclivegui/gui/misc/color_dropdowns.py +++ b/dlclivegui/gui/misc/color_dropdowns.py @@ -24,7 +24,8 @@ QStyleOptionComboBox, ) -BGR = tuple[int, int, int] +from dlclivegui.config import BGR + TEnum = TypeVar("TEnum") diff --git a/dlclivegui/utils/skeleton.py b/dlclivegui/utils/skeleton.py index 3904572..1bacbf2 100644 --- a/dlclivegui/utils/skeleton.py +++ b/dlclivegui/utils/skeleton.py @@ -1,3 +1,5 @@ +"""Skeleton definition, validation, and drawing utilities.""" + # dlclivegui/utils/skeleton.py from __future__ import annotations @@ -11,6 +13,8 @@ import yaml from pydantic import BaseModel, Field, ValidationError, field_validator +from dlclivegui.config import BGR, SkeletonColorMode, SkeletonStyle + # ############### # # Status & code # @@ -60,27 +64,6 @@ class SkeletonValidationError(SkeletonLoadError): # Skeleton display # # ################## # -BGR = tuple[int, int, int] # (B, G, R) color format - - -class SkeletonColorMode(str, Enum): - SOLID = "solid" - GRADIENT_KEYPOINTS = "gradient_keypoints" # use endpoint keypoint colors - - -@dataclass -class SkeletonStyle: - mode: SkeletonColorMode = SkeletonColorMode.SOLID - color: BGR = (0, 255, 255) # default if SOLID - thickness: int = 2 # base thickness in pixels - gradient_steps: int = 16 # segments per edge when gradient - scale_with_zoom: bool = True # scale thickness with (sx, sy) - - def effective_thickness(self, sx: float, sy: float) -> int: - if not self.scale_with_zoom: - return max(1, int(self.thickness)) - return max(1, int(round(self.thickness * min(sx, sy)))) - class SkeletonStyleModel(BaseModel): mode: SkeletonColorMode = SkeletonColorMode.SOLID