#!/usr/bin/env python3
"""
Fall Detection Gradio App (Batch Processing Pipeline)

A fall detection demo using a 2-stage YOLOv11-Pose + ST-GCN pipeline.
Optimized with batch processing for fast inference.

Pipeline:
    1. Load all frames in a batch with cv2 (decord was considered; cv2 is faster here)
    2. YOLO Pose batch inference -> accumulate keypoints
    3. Window-level ST-GCN batch inference
    4. Visualize only the -1 s to +2 s span around the detected fall

Usage (local):
    python pipeline/demo_gradio/app.py

Author:  Fall Detection Pipeline Team
Created: 2025-11-27
"""
import os
import subprocess
import sys
import tempfile
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Iterable, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from huggingface_hub import hf_hub_download
# Add the project root to the Python path
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# ZeroGPU compatibility setup
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
# -----------------------------------------------------------------------------
# Custom theme (PRITHIVSAKTHIUR style)
# -----------------------------------------------------------------------------
colors.custom_color = colors.Color(
    name="custom_color",
    c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1",
    c300="#7DB3D2", c400="#529AC3", c500="#4682B4",
    c600="#3E72A0", c700="#36638C", c800="#2E5378",
    c900="#264364", c950="#1E3450",
)
class CustomTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.custom_color,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_text_color="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            slider_color="*secondary_500",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
        )


custom_theme = CustomTheme()
# -----------------------------------------------------------------------------
# CSS styles
# -----------------------------------------------------------------------------
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
    background-color: #4682B4 !important;
    color: white !important;
}
.submit-btn:hover {
    background-color: #5A9BD4 !important;
}
.result-label {
    font-size: 1.5em !important;
    font-weight: bold !important;
    padding: 10px !important;
    border-radius: 8px !important;
}
.fall-detected {
    background-color: #FF4444 !important;
    color: white !important;
}
.non-fall {
    background-color: #44BB44 !important;
    color: white !important;
}
"""
# -----------------------------------------------------------------------------
# Device setup
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# GPU decorator (local / HF Spaces compatible)
# -----------------------------------------------------------------------------
def gpu_decorator(duration: int = 120):
    """Run the function as-is locally; request a GPU allocation on Spaces."""
    def decorator(func):
        if SPACES_AVAILABLE:
            return spaces.GPU(duration=duration)(func)
        return func
    return decorator
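
# Minimal sketch of the pattern (the decorator is applied to process_video
# below; `duration` is the maximum number of seconds a ZeroGPU allocation
# may be held):
#
#   @gpu_decorator(duration=120)
#   def heavy_inference(...):   # wrapped by spaces.GPU on ZeroGPU Spaces,
#       ...                     # runs unchanged on a local machine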
# -----------------------------------------------------------------------------
# Model download (HuggingFace Hub)
# -----------------------------------------------------------------------------
HF_MODEL_REPO = "YoungjaeDev/fall-detection-models"


def download_models() -> tuple[str, str]:
    """Download models from the HuggingFace Hub (cached)."""
    # Check local paths first (development environment)
    local_pose = Path("yolo11m-pose.pt")
    local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
    if local_pose.exists() and local_stgcn.exists():
        return str(local_pose), str(local_stgcn)

    # Download from the HuggingFace Hub
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError(
            "The HF_TOKEN environment variable is not set. "
            "HF_TOKEN is required to access the private model repository."
        )
    try:
        pose_model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="yolo11m-pose.pt", token=token
        )
        stgcn_checkpoint = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="best_acc.pth", token=token
        )
    except Exception as e:
        raise RuntimeError(f"Model download failed: {e}") from e
    return pose_model_path, stgcn_checkpoint
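
# Setup sketch (assumed, matching the check above): provide HF_TOKEN before
# launch so the private repo is reachable, either via the shell or as a
# Space secret:
#
#   export HF_TOKEN=hf_...    # shell, before `python pipeline/demo_gradio/app.py`
#   # or: Space Settings -> Variables and secrets -> add HF_TOKEN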
# -----------------------------------------------------------------------------
# Model singletons (lazy loading)
# -----------------------------------------------------------------------------
_pose_estimator = None
_stgcn_classifier = None


def get_pose_estimator():
    """Return the PoseEstimator singleton."""
    global _pose_estimator
    if _pose_estimator is None:
        from pipeline.models.pose_estimator import PoseEstimator
        pose_model_path, _ = download_models()
        _pose_estimator = PoseEstimator(
            model_path=pose_model_path,
            conf_threshold=0.5,
            device=str(device)
        )
    return _pose_estimator


def get_stgcn_classifier():
    """Return the STGCNClassifier singleton."""
    global _stgcn_classifier
    if _stgcn_classifier is None:
        from pipeline.models.stgcn_classifier import STGCNClassifier
        _, stgcn_checkpoint = download_models()
        _stgcn_classifier = STGCNClassifier(
            checkpoint_path=stgcn_checkpoint,
            fall_threshold=0.7,
            device=str(device)
        )
    return _stgcn_classifier
# -----------------------------------------------------------------------------
# Frame loading (uses cv2 - faster than decord for most videos)
# -----------------------------------------------------------------------------
def load_video_frames(video_path: str) -> Tuple[np.ndarray, float]:
    """
    Load all frames from a video (using cv2).

    Returns:
        frames: (N, H, W, C) numpy array (BGR)
        fps: frame rate
    """
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
    cap.release()
    return np.array(frames), fps
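
# The module docstring mentions decord; a minimal alternative loader sketch,
# assuming `pip install decord` (not used by the pipeline below). Note that
# decord yields RGB frames while cv2 yields BGR, hence the channel flip.
def load_video_frames_decord(video_path: str) -> Tuple[np.ndarray, float]:
    """Load all frames with decord (illustrative alternative to cv2)."""
    from decord import VideoReader, cpu
    vr = VideoReader(video_path, ctx=cpu(0))
    fps = float(vr.get_avg_fps())
    frames = vr.get_batch(list(range(len(vr)))).asnumpy()  # (N, H, W, 3), RGB
    return np.ascontiguousarray(frames[..., ::-1]), fps    # RGB -> BGR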
# -----------------------------------------------------------------------------
# Batch pose inference
# -----------------------------------------------------------------------------
def extract_all_keypoints(
    frames: np.ndarray,
    pose_estimator,
    batch_size: int = 8,
    progress_callback=None
) -> list[Optional[np.ndarray]]:
    """
    Run batched pose inference over all frames.

    Args:
        frames: (N, H, W, C) full video frames
        pose_estimator: PoseEstimator instance
        batch_size: batch size
        progress_callback: progress callback function

    Returns:
        keypoints_list: N keypoints entries, each (17, 3) or None
    """
    n_frames = len(frames)
    all_keypoints = []
    for i in range(0, n_frames, batch_size):
        batch = list(frames[i:i+batch_size])
        batch_keypoints = pose_estimator.extract_batch(batch)
        all_keypoints.extend(batch_keypoints)
        if progress_callback:
            progress_callback(min(i + batch_size, n_frames), n_frames)
    return all_keypoints
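
# Usage sketch (assumed): with frames from load_video_frames,
#
#   keypoints_list = extract_all_keypoints(frames, get_pose_estimator())
#   # len(keypoints_list) == len(frames); entries are (17, 3) [x, y, conf]
#   # arrays, or None for frames where no person was detected.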
# -----------------------------------------------------------------------------
# Window construction and ST-GCN batch inference
# -----------------------------------------------------------------------------
def create_windows_and_predict(
    keypoints_list: list[Optional[np.ndarray]],
    stgcn_classifier,
    window_size: int = 60,
    stride: int = 5,
    fall_threshold: float = 0.7
) -> Tuple[list[int], list[float], Optional[int]]:
    """
    Build windows from the keypoints, then run batched ST-GCN inference.

    Args:
        keypoints_list: per-frame keypoints list
        stgcn_classifier: STGCNClassifier instance
        window_size: window size (number of frames)
        stride: inference interval (once every N frames)
        fall_threshold: fall decision threshold

    Returns:
        frame_indices: frame indices that have an ST-GCN prediction
        fall_probs: fall probability per frame (class-1 probability)
        first_fall_frame: frame index of the first detected fall (None if none)
    """
    n_frames = len(keypoints_list)

    # Replace None entries with empty keypoints
    processed_keypoints = []
    for kpts in keypoints_list:
        if kpts is None:
            processed_keypoints.append(np.zeros((17, 3), dtype=np.float32))
        else:
            processed_keypoints.append(kpts)

    # Build windows (at stride intervals)
    frame_indices = []
    windows = []
    for frame_idx in range(window_size - 1, n_frames, stride):
        # Each window consists of the previous window_size frames
        window_keypoints = processed_keypoints[frame_idx - window_size + 1:frame_idx + 1]
        # (T, V, C) -> (C, T, V, M) conversion
        window = np.array(window_keypoints)   # (T=60, V=17, C=3)
        window = window.transpose(2, 0, 1)    # (C=3, T=60, V=17)
        window = np.expand_dims(window, -1)   # (C=3, T=60, V=17, M=1)
        frame_indices.append(frame_idx)
        windows.append(window.astype(np.float32))

    if not windows:
        return [], [], None

    # Batched ST-GCN inference
    predictions, confidences, fall_probs = stgcn_classifier.predict_batch(windows)

    # Find the first fall-detected frame
    first_fall_frame = None
    for i, (pred, fall_prob) in enumerate(zip(predictions, fall_probs)):
        if pred == 1 and fall_prob >= fall_threshold:
            first_fall_frame = frame_indices[i]
            break

    return frame_indices, fall_probs.tolist(), first_fall_frame
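
# Shape walk-through (illustrative): each window of 60 frames x 17 COCO
# keypoints x (x, y, conf) channels moves through
#
#   np.array(window_keypoints)   # (T=60, V=17, C=3)
#   .transpose(2, 0, 1)          # (C=3, T=60, V=17)
#   np.expand_dims(..., -1)      # (C=3, T=60, V=17, M=1)
#
# which matches the (C, T, V, M) layout ST-GCN expects, with M the number
# of tracked bodies (fixed to 1 here).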
# -----------------------------------------------------------------------------
# Visualization worker function (for ProcessPoolExecutor)
# -----------------------------------------------------------------------------
# How long to show the FALL DETECTED text (seconds)
FALL_DISPLAY_DURATION = 2.0


def _visualize_single_frame(args: tuple) -> Tuple[int, np.ndarray]:
    """Single-frame visualization worker (simplified version)."""
    (frame_idx, frame, keypoints, show_fall_text,
     viz_keypoints, viz_scale) = args
    # Project imports (inside the worker process)
    import sys
    from pathlib import Path
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
    from pipeline.visualization import visualize_fall_simple

    vis_frame = visualize_fall_simple(
        frame=frame,
        keypoints=keypoints if keypoints is not None and keypoints.sum() > 0 else None,
        show_fall_text=show_fall_text,
        keypoint_mode=viz_keypoints,
        output_scale=viz_scale
    )
    return frame_idx, vis_frame
def visualize_clip_parallel(
    frames: np.ndarray,
    keypoints_list: list[Optional[np.ndarray]],
    frame_indices: list[int],
    fall_probs: list[float],
    clip_start: int,
    clip_end: int,
    fps: float,
    first_fall_frame: Optional[int] = None,
    fall_threshold: float = 0.7,
    viz_keypoints: str = "all",
    viz_scale: float = 1.0,
    num_workers: int = 4
) -> list[np.ndarray]:
    """
    Visualize the clip span in parallel (simplified version).

    Args:
        frames: all frames
        keypoints_list: all keypoints
        frame_indices: ST-GCN prediction frame indices
        fall_probs: per-frame fall probabilities
        clip_start: clip start index
        clip_end: clip end index
        fps: frame rate
        first_fall_frame: first fall-detected frame (for flicker prevention)
        fall_threshold: fall decision threshold
        viz_keypoints: keypoint display mode
        viz_scale: output scale
        num_workers: number of parallel workers

    Returns:
        vis_frames: list of visualized frames
    """
    # Flicker prevention: show FALL DETECTED for N seconds after the first fall
    fall_display_end_frame = None
    if first_fall_frame is not None:
        fall_display_end_frame = first_fall_frame + int(fps * FALL_DISPLAY_DURATION)

    # Prepare visualization arguments
    viz_args = []
    for i in range(clip_start, clip_end):
        frame = frames[i]
        keypoints = keypoints_list[i]
        # Decide whether to show the FALL DETECTED text (flicker prevention)
        show_fall_text = False
        if first_fall_frame is not None and fall_display_end_frame is not None:
            if first_fall_frame <= i <= fall_display_end_frame:
                show_fall_text = True
        args = (
            i,               # frame_idx
            frame,           # frame
            keypoints,       # keypoints
            show_fall_text,  # show_fall_text (flicker prevention applied)
            viz_keypoints,   # viz_keypoints
            viz_scale        # viz_scale
        )
        viz_args.append(args)

    # Parallel visualization
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(_visualize_single_frame, viz_args))

    # Sort back into frame order
    results.sort(key=lambda x: x[0])
    vis_frames = [frame for _, frame in results]
    return vis_frames
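
# Note: executor.map pickles each (frame, keypoints) tuple over to a worker
# process, which is cheap for the short -1 s / +2 s clip. A serial fallback
# (assumed sketch) if multiprocessing is ever undesirable:
#
#   vis_frames = [_visualize_single_frame(a)[1] for a in viz_args]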
# -----------------------------------------------------------------------------
# Probability graph creation
# -----------------------------------------------------------------------------
def create_probability_graph(
    frame_indices: list[int],
    fall_probs: list[float],
    fall_threshold: float = 0.7,
    fps: float = 30.0
) -> go.Figure:
    """Create the fall probability graph (x-axis: time)."""
    # Convert frame indices -> time (seconds)
    time_seconds = [idx / fps for idx in frame_indices]

    fig = go.Figure()
    # Probability line
    fig.add_trace(go.Scatter(
        x=time_seconds,
        y=fall_probs,
        mode='lines',
        name='Fall Probability',
        line=dict(color='#4682B4', width=2),
        fill='tozeroy',
        fillcolor='rgba(70, 130, 180, 0.3)'
    ))
    # Threshold line
    fig.add_hline(
        y=fall_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Threshold ({fall_threshold})",
        annotation_position="right"
    )
    # Layout
    fig.update_layout(
        title="Fall Detection Probability Over Time",
        xaxis_title="Time (seconds)",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=300,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig
# -----------------------------------------------------------------------------
# Smart clip extraction settings
# -----------------------------------------------------------------------------
CLIP_PRE_FALL_SECONDS = 1.0   # 1 second before the fall
CLIP_POST_FALL_SECONDS = 2.0  # 2 seconds after the fall
# -----------------------------------------------------------------------------
# Main inference function
# -----------------------------------------------------------------------------
@gpu_decorator(duration=120)  # request a GPU on Spaces for up to 120 s (no-op locally)
def process_video(
    video_path: str,
    fall_threshold: float,
    viz_keypoints: str,
    progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
    """
    Process a video and detect falls (batch processing pipeline).

    Pipeline:
        1. Load all frames in a batch with cv2
        2. YOLO Pose batch inference -> accumulate keypoints
        3. Window-level ST-GCN batch inference
        4. Visualize only the -1 s to +2 s span around the fall

    Args:
        video_path: input video path
        fall_threshold: fall decision threshold (0.0-1.0)
        viz_keypoints: keypoint display mode ('all' or 'major')
        progress: Gradio progress indicator

    Returns:
        output_video_path: result clip path (when a fall is detected) or None
        probability_graph: probability graph
        result_text: final decision text
    """
    if video_path is None:
        return None, None, "Please upload a video."

    try:
        # Stage 0: load models
        progress(0.05, desc="Loading models...")
        pose_estimator = get_pose_estimator()
        stgcn_classifier = get_stgcn_classifier()
        stgcn_classifier.fall_threshold = fall_threshold

        # Stage 1: load frames (cv2)
        progress(0.1, desc="Loading video...")
        frames, fps = load_video_frames(video_path)
        n_frames = len(frames)
        if n_frames == 0:
            return None, None, "Could not read the video."

        # Validate video length (guard against the 120 s GPU timeout)
        video_duration = n_frames / fps
        if video_duration > 60:
            return None, None, (
                f"The video is too long. "
                f"Video length: {video_duration:.1f} s (limit: 60 s). "
                f"Please upload a video no longer than 60 seconds."
            )

        # Stage 2: batch pose inference
        progress(0.15, desc="Extracting poses...")

        def pose_progress(current, total):
            pct = 0.15 + 0.35 * (current / total)
            progress(pct, desc=f"Extracting poses... ({current}/{total})")

        keypoints_list = extract_all_keypoints(
            frames, pose_estimator,
            batch_size=8,
            progress_callback=pose_progress
        )

        # Stage 3: batch ST-GCN inference
        progress(0.55, desc="Analyzing for falls...")
        frame_indices, fall_probs, first_fall_frame = create_windows_and_predict(
            keypoints_list,
            stgcn_classifier,
            window_size=60,
            stride=5,
            fall_threshold=fall_threshold
        )

        # Create the probability graph
        progress(0.7, desc="Creating graph...")
        fig = None
        if frame_indices and fall_probs:
            fig = create_probability_graph(frame_indices, fall_probs, fall_threshold, fps)

        # No fall detected
        if first_fall_frame is None:
            progress(1.0, desc="Done!")
            result_text = (
                f"[Non-Fall] No fall was detected.\n"
                f"Frames analyzed: {n_frames}"
            )
            return None, fig, result_text

        # Stage 4: visualize only the fall span
        progress(0.75, desc="Rendering clip...")
        pre_fall_frames = int(fps * CLIP_PRE_FALL_SECONDS)
        post_fall_frames = int(fps * CLIP_POST_FALL_SECONDS)
        clip_start = max(0, first_fall_frame - pre_fall_frames)
        clip_end = min(n_frames, first_fall_frame + post_fall_frames)
        vis_frames = visualize_clip_parallel(
            frames=frames,
            keypoints_list=keypoints_list,
            frame_indices=frame_indices,
            fall_probs=fall_probs,
            clip_start=clip_start,
            clip_end=clip_end,
            fps=fps,
            first_fall_frame=first_fall_frame,  # for flicker prevention
            fall_threshold=fall_threshold,
            viz_keypoints=viz_keypoints,
            viz_scale=1.0,
            num_workers=4
        )
        if not vis_frames:
            progress(1.0, desc="Done!")
            return None, fig, "Clip extraction failed."

        # Stage 5: video encoding
        progress(0.9, desc="Encoding clip...")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        clip_height, clip_width = vis_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (clip_width, clip_height))
        for vis_frame in vis_frames:
            out.write(vis_frame)
        out.release()

        # Re-encode to H.264 (browser compatible)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_h264 = tmp.name
        subprocess.run(
            [
                'ffmpeg', '-y', '-i', output_path,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
                output_h264, '-loglevel', 'quiet'
            ],
            check=False,
            capture_output=True
        )
        # Clean up the intermediate file
        if os.path.exists(output_path):
            os.remove(output_path)
        final_output = output_h264 if os.path.exists(output_h264) else None

        # Final decision
        progress(1.0, desc="Done!")
        fall_time = first_fall_frame / fps
        clip_duration = len(vis_frames) / fps
        result_text = (
            f"[FALL DETECTED] A fall was detected!\n"
            f"Fall moment: {fall_time:.2f} s (frame #{first_fall_frame})\n"
            f"Clip length: {clip_duration:.1f} s ({len(vis_frames)} frames)"
        )
        return final_output, fig, result_text

    except Exception as e:
        import traceback
        error_msg = f"Error during processing: {str(e)}\n{traceback.format_exc()}"
        return None, None, error_msg
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
def create_demo() -> gr.Blocks:
    """Create the Gradio demo."""
    with gr.Blocks(theme=custom_theme, css=css) as demo:
        gr.Markdown(
            """
            # Fall Detection Demo
            A fall detection demo using a 2-stage YOLOv11-Pose + ST-GCN pipeline.
            Upload a video and the app analyzes it for falls, returning a result
            video and a probability graph.

            **Pipeline:**
            - Stage 1: YOLOv11m-pose (Pose Estimation) - Batch Processing
            - Stage 2: ST-GCN (Temporal Classification) - Batch Processing
            - Window Size: 60 frames (2s @ 30fps)
            """,
            elem_id="main-title"
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("### Input")
                video_input = gr.Video(
                    label="Upload Video",
                    sources=["upload"],
                )
                with gr.Accordion("Advanced Settings", open=False):
                    fall_threshold = gr.Slider(
                        minimum=0.5,
                        maximum=0.95,
                        value=0.7,
                        step=0.05,
                        label="Fall decision threshold",
                        info="Recommended: 0.7-0.85"
                    )
                    viz_keypoints = gr.Radio(
                        choices=["all", "major"],
                        value="all",
                        label="Keypoint display",
                        info="all: all 17 keypoints, major: 9 main keypoints"
                    )
                submit_btn = gr.Button(
                    "Start Analysis",
                    variant="primary",
                    elem_classes="submit-btn"
                )
            with gr.Column(scale=1):
                # Output section
                gr.Markdown("### Results")
                result_text = gr.Textbox(
                    label="Decision",
                    lines=3,
                    interactive=False
                )
                video_output = gr.Video(
                    label="Result Video",
                )
                prob_graph = gr.Plot(
                    label="Fall Probability Graph",
                )
        # Example videos
        gr.Markdown("### Example Videos")
        example_dir = Path(__file__).parent / "examples"
        examples = []
        if example_dir.exists():
            for ext in ["*.mp4", "*.avi", "*.mov"]:
                examples.extend([str(p) for p in example_dir.glob(ext)])
        if examples:
            gr.Examples(
                examples=[[ex, 0.7, "all"] for ex in sorted(examples)],
                inputs=[video_input, fall_threshold, viz_keypoints],
                outputs=[video_output, prob_graph, result_text],
                fn=process_video,
                cache_examples=False,
                examples_per_page=4,
                label="Example Videos",
            )
        # Event wiring
        submit_btn.click(
            fn=process_video,
            inputs=[video_input, fall_threshold, viz_keypoints],
            outputs=[video_output, prob_graph, result_text],
        )
        # Footer
        gr.Markdown(
            """
            ---
            **References:**
            - [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
            - [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
            - AI Hub Fall Detection Dataset
            """
        )
    return demo
# -----------------------------------------------------------------------------
# Main entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )