Spaces:
Sleeping
Sleeping
YoungjaeDev
Claude
commited on
Commit
·
dc83241
1
Parent(s):
b20de68
fix: predict_batch에 batch_size 청킹 추가 - GPU OOM 방지
Browse files- batch_size 파라미터 추가 (기본값: 32)
- 대량 윈도우 입력 시 배치 청킹으로 메모리 사용 제어
- predict 메서드에 .float() 추가 - dtype 일관성 수정
- CodeRabbit 리뷰 반영
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- models/stgcn_classifier.py +43 -29
models/stgcn_classifier.py
CHANGED
|
@@ -103,7 +103,7 @@ class STGCNClassifier:
|
|
| 103 |
window_input = window
|
| 104 |
|
| 105 |
# ST-GCN inference
|
| 106 |
-
window_tensor = torch.from_numpy(window_input).unsqueeze(0).to(self.device) # (1, C, T, V, M)
|
| 107 |
|
| 108 |
with torch.no_grad():
|
| 109 |
outputs = self.model(window_tensor)
|
|
@@ -121,6 +121,7 @@ class STGCNClassifier:
|
|
| 121 |
def predict_batch(
|
| 122 |
self,
|
| 123 |
windows: list[np.ndarray],
|
|
|
|
| 124 |
normalize: bool = True,
|
| 125 |
debug: bool = False
|
| 126 |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
@@ -129,6 +130,7 @@ class STGCNClassifier:
|
|
| 129 |
|
| 130 |
Args:
|
| 131 |
windows: [(C, T, V, M), ...] ST-GCN 입력 윈도우 리스트
|
|
|
|
| 132 |
normalize: hip center 정규화 적용 여부
|
| 133 |
debug: 디버그 로그 출력 여부
|
| 134 |
|
|
@@ -140,34 +142,46 @@ class STGCNClassifier:
|
|
| 140 |
if not windows:
|
| 141 |
return np.array([]), np.array([]), np.array([])
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
batch_list
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
def is_fall(self, prediction: int, confidence: float) -> bool:
|
| 173 |
"""
|
|
|
|
| 103 |
window_input = window
|
| 104 |
|
| 105 |
# ST-GCN inference
|
| 106 |
+
window_tensor = torch.from_numpy(window_input).float().unsqueeze(0).to(self.device) # (1, C, T, V, M)
|
| 107 |
|
| 108 |
with torch.no_grad():
|
| 109 |
outputs = self.model(window_tensor)
|
|
|
|
| 121 |
def predict_batch(
|
| 122 |
self,
|
| 123 |
windows: list[np.ndarray],
|
| 124 |
+
batch_size: int = 32,
|
| 125 |
normalize: bool = True,
|
| 126 |
debug: bool = False
|
| 127 |
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
|
|
| 130 |
|
| 131 |
Args:
|
| 132 |
windows: [(C, T, V, M), ...] ST-GCN 입력 윈도우 리스트
|
| 133 |
+
batch_size: GPU 배치 크기 (기본값: 32, OOM 방지용)
|
| 134 |
normalize: hip center 정규화 적용 여부
|
| 135 |
debug: 디버그 로그 출력 여부
|
| 136 |
|
|
|
|
| 142 |
if not windows:
|
| 143 |
return np.array([]), np.array([]), np.array([])
|
| 144 |
|
| 145 |
+
all_predictions = []
|
| 146 |
+
all_confidences = []
|
| 147 |
+
all_fall_probs = []
|
| 148 |
+
|
| 149 |
+
for chunk_start in range(0, len(windows), batch_size):
|
| 150 |
+
chunk_windows = windows[chunk_start:chunk_start + batch_size]
|
| 151 |
+
|
| 152 |
+
batch_list = []
|
| 153 |
+
for window in chunk_windows:
|
| 154 |
+
if normalize:
|
| 155 |
+
window_input = normalize_skeleton(window, method='hip_center')
|
| 156 |
+
else:
|
| 157 |
+
window_input = window
|
| 158 |
+
batch_list.append(torch.from_numpy(window_input).float())
|
| 159 |
+
|
| 160 |
+
batch_tensor = torch.stack(batch_list).to(self.device)
|
| 161 |
+
|
| 162 |
+
with torch.no_grad():
|
| 163 |
+
outputs = self.model(batch_tensor)
|
| 164 |
+
probs = torch.softmax(outputs, dim=1)
|
| 165 |
+
preds = torch.argmax(outputs, dim=1)
|
| 166 |
+
|
| 167 |
+
predictions = preds.cpu().numpy()
|
| 168 |
+
confidences = probs[torch.arange(len(preds)), preds].cpu().numpy()
|
| 169 |
+
fall_probs = probs[:, 1].cpu().numpy()
|
| 170 |
+
|
| 171 |
+
all_predictions.append(predictions)
|
| 172 |
+
all_confidences.append(confidences)
|
| 173 |
+
all_fall_probs.append(fall_probs)
|
| 174 |
+
|
| 175 |
+
if debug:
|
| 176 |
+
for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
|
| 177 |
+
global_idx = chunk_start + i
|
| 178 |
+
self.logger.debug(f" Batch[{global_idx}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")
|
| 179 |
+
|
| 180 |
+
return (
|
| 181 |
+
np.concatenate(all_predictions),
|
| 182 |
+
np.concatenate(all_confidences),
|
| 183 |
+
np.concatenate(all_fall_probs)
|
| 184 |
+
)
|
| 185 |
|
| 186 |
def is_fall(self, prediction: int, confidence: float) -> bool:
|
| 187 |
"""
|