YoungjaeDev Claude committed on
Commit dc83241 · 1 Parent(s): b20de68

fix: add batch_size chunking to predict_batch - prevent GPU OOM


- Add a batch_size parameter (default: 32)
- Control memory use via batch chunking when a large number of windows is passed in
- Add .float() to the predict method - fix dtype consistency
- Address CodeRabbit review feedback

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
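For reference, a minimal usage sketch of the new parameter. The classifier construction and the (C, T, V, M) window shape below are placeholders and not taken from this commit; only predict_batch's signature comes from the diff.

import numpy as np
from models.stgcn_classifier import STGCNClassifier

# Hypothetical setup: constructor arguments and the window shape (3, 30, 17, 1)
# are assumptions for illustration only.
classifier = STGCNClassifier()
windows = [np.random.randn(3, 30, 17, 1).astype(np.float32) for _ in range(1000)]

# Windows are now processed in chunks of batch_size, so at most 32 windows
# are resident on the GPU at a time instead of all 1000 at once.
predictions, confidences, fall_probs = classifier.predict_batch(
    windows,
    batch_size=32,   # default; lower it if OOM still occurs
    normalize=True,
)
assert len(predictions) == len(windows)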

Files changed (1)
  1. models/stgcn_classifier.py +43 -29
models/stgcn_classifier.py CHANGED
@@ -103,7 +103,7 @@ class STGCNClassifier:
             window_input = window
 
         # ST-GCN inference
-        window_tensor = torch.from_numpy(window_input).unsqueeze(0).to(self.device)  # (1, C, T, V, M)
+        window_tensor = torch.from_numpy(window_input).float().unsqueeze(0).to(self.device)  # (1, C, T, V, M)
 
         with torch.no_grad():
             outputs = self.model(window_tensor)
@@ -121,6 +121,7 @@
     def predict_batch(
         self,
         windows: list[np.ndarray],
+        batch_size: int = 32,
         normalize: bool = True,
         debug: bool = False
     ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -129,6 +130,7 @@
 
         Args:
             windows: [(C, T, V, M), ...] list of ST-GCN input windows
+            batch_size: GPU batch size (default: 32, to prevent OOM)
             normalize: whether to apply hip-center normalization
             debug: whether to emit debug logs
 
@@ -140,34 +142,46 @@
         if not windows:
             return np.array([]), np.array([]), np.array([])
 
-        # Normalize and prepare the batch tensor
-        batch_list = []
-        for window in windows:
-            if normalize:
-                window_input = normalize_skeleton(window, method='hip_center')
-            else:
-                window_input = window
-            batch_list.append(torch.from_numpy(window_input).float())
-
-        # Build the batch tensor (N, C, T, V, M)
-        batch_tensor = torch.stack(batch_list).to(self.device)
-
-        with torch.no_grad():
-            outputs = self.model(batch_tensor)
-            probs = torch.softmax(outputs, dim=1)
-            preds = torch.argmax(outputs, dim=1)
-
-        predictions = preds.cpu().numpy()
-        # Use the predicted class's probability as the confidence
-        confidences = probs[torch.arange(len(preds)), preds].cpu().numpy()
-        # Probability of the Fall class (class 1) - for plotting
-        fall_probs = probs[:, 1].cpu().numpy()
-
-        if debug:
-            for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
-                self.logger.debug(f" Batch[{i}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")
-
-        return predictions, confidences, fall_probs
+        all_predictions = []
+        all_confidences = []
+        all_fall_probs = []
+
+        for chunk_start in range(0, len(windows), batch_size):
+            chunk_windows = windows[chunk_start:chunk_start + batch_size]
+
+            batch_list = []
+            for window in chunk_windows:
+                if normalize:
+                    window_input = normalize_skeleton(window, method='hip_center')
+                else:
+                    window_input = window
+                batch_list.append(torch.from_numpy(window_input).float())
+
+            batch_tensor = torch.stack(batch_list).to(self.device)
+
+            with torch.no_grad():
+                outputs = self.model(batch_tensor)
+                probs = torch.softmax(outputs, dim=1)
+                preds = torch.argmax(outputs, dim=1)
+
+            predictions = preds.cpu().numpy()
+            confidences = probs[torch.arange(len(preds)), preds].cpu().numpy()
+            fall_probs = probs[:, 1].cpu().numpy()
+
+            all_predictions.append(predictions)
+            all_confidences.append(confidences)
+            all_fall_probs.append(fall_probs)
+
+            if debug:
+                for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
+                    global_idx = chunk_start + i
+                    self.logger.debug(f" Batch[{global_idx}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")
+
+        return (
+            np.concatenate(all_predictions),
+            np.concatenate(all_confidences),
+            np.concatenate(all_fall_probs)
+        )
 
     def is_fall(self, prediction: int, confidence: float) -> bool:
         """