Really-amin's picture
Upload 577 files
b190b45 verified
#!/usr/bin/env python3
"""
ML Training Service
===================
سرویس آموزش مدل‌های یادگیری ماشین با قابلیت پیگیری پیشرفت و ذخیره checkpoint
"""
from typing import Optional, List, Dict, Any
from datetime import datetime
from sqlalchemy.orm import Session
from sqlalchemy import and_, desc
import uuid
import logging
import json
from database.models import (
Base, MLTrainingJob, TrainingStep, TrainingStatus
)
logger = logging.getLogger(__name__)
class MLTrainingService:
"""سرویس اصلی آموزش مدل‌های ML"""
def __init__(self, db_session: Session):
"""
Initialize the ML training service.
Args:
db_session: SQLAlchemy database session
"""
self.db = db_session
def start_training(
self,
model_name: str,
training_data_start: datetime,
training_data_end: datetime,
batch_size: int = 32,
learning_rate: Optional[float] = None,
config: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Start training a model.
Args:
model_name: Name of the model to train
training_data_start: Start date for training data
training_data_end: End date for training data
batch_size: Training batch size
learning_rate: Learning rate (optional)
config: Additional training configuration
Returns:
Dict containing training job details
"""
try:
# Generate job ID
job_id = f"TR-{uuid.uuid4().hex[:12].upper()}"
# Create training job
job = MLTrainingJob(
job_id=job_id,
model_name=model_name,
model_version="1.0.0",
status=TrainingStatus.PENDING,
training_data_start=training_data_start,
training_data_end=training_data_end,
batch_size=batch_size,
learning_rate=learning_rate or 0.001,
config=json.dumps(config) if config else None
)
self.db.add(job)
self.db.commit()
self.db.refresh(job)
logger.info(f"Created training job {job_id} for model {model_name}")
# In production, this would start training in background
# For now, we just return the job details
return self._job_to_dict(job)
except Exception as e:
self.db.rollback()
logger.error(f"Error starting training: {e}", exc_info=True)
raise
def execute_training_step(
self,
job_id: str,
step_number: int,
loss: Optional[float] = None,
accuracy: Optional[float] = None,
learning_rate: Optional[float] = None,
metrics: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Execute a single training step.
Args:
job_id: Training job ID
step_number: Step number
loss: Training loss
accuracy: Training accuracy
learning_rate: Current learning rate
metrics: Additional metrics
Returns:
Dict containing step details
"""
try:
# Get training job
job = self.db.query(MLTrainingJob).filter(
MLTrainingJob.job_id == job_id
).first()
if not job:
raise ValueError(f"Training job {job_id} not found")
if job.status != TrainingStatus.RUNNING:
raise ValueError(f"Training job {job_id} is not in RUNNING status")
# Create training step
step = TrainingStep(
job_id=job_id,
step_number=step_number,
loss=loss,
accuracy=accuracy,
learning_rate=learning_rate,
metrics=json.dumps(metrics) if metrics else None
)
self.db.add(step)
# Update job
job.current_step = step_number
if loss is not None:
job.loss = loss
if accuracy is not None:
job.accuracy = accuracy
if learning_rate is not None:
job.learning_rate = learning_rate
self.db.commit()
self.db.refresh(step)
logger.info(f"Training step {step_number} executed for job {job_id}")
return self._step_to_dict(step)
except Exception as e:
self.db.rollback()
logger.error(f"Error executing training step: {e}", exc_info=True)
raise
def get_training_status(self, job_id: str) -> Dict[str, Any]:
"""
Get the current training status.
Args:
job_id: Training job ID
Returns:
Dict containing training status
"""
try:
job = self.db.query(MLTrainingJob).filter(
MLTrainingJob.job_id == job_id
).first()
if not job:
raise ValueError(f"Training job {job_id} not found")
return self._job_to_dict(job)
except Exception as e:
logger.error(f"Error getting training status: {e}", exc_info=True)
raise
def get_training_history(
self,
model_name: Optional[str] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""
Get training history.
Args:
model_name: Filter by model name (optional)
limit: Maximum number of jobs to return
Returns:
List of training job dictionaries
"""
try:
query = self.db.query(MLTrainingJob)
if model_name:
query = query.filter(MLTrainingJob.model_name == model_name)
jobs = query.order_by(desc(MLTrainingJob.created_at)).limit(limit).all()
return [self._job_to_dict(job) for job in jobs]
except Exception as e:
logger.error(f"Error retrieving training history: {e}", exc_info=True)
raise
def update_training_status(
self,
job_id: str,
status: str,
checkpoint_path: Optional[str] = None,
error_message: Optional[str] = None
) -> Dict[str, Any]:
"""
Update training job status.
Args:
job_id: Training job ID
status: New status
checkpoint_path: Path to checkpoint (optional)
error_message: Error message if failed (optional)
Returns:
Dict containing updated job details
"""
try:
job = self.db.query(MLTrainingJob).filter(
MLTrainingJob.job_id == job_id
).first()
if not job:
raise ValueError(f"Training job {job_id} not found")
job.status = TrainingStatus[status.upper()]
if status.upper() == "RUNNING" and not job.started_at:
job.started_at = datetime.utcnow()
if status.upper() in ["COMPLETED", "FAILED", "CANCELLED"]:
job.completed_at = datetime.utcnow()
if checkpoint_path:
job.checkpoint_path = checkpoint_path
if error_message:
job.error_message = error_message
self.db.commit()
self.db.refresh(job)
return self._job_to_dict(job)
except Exception as e:
self.db.rollback()
logger.error(f"Error updating training status: {e}", exc_info=True)
raise
def _job_to_dict(self, job: MLTrainingJob) -> Dict[str, Any]:
"""Convert job model to dictionary."""
config = json.loads(job.config) if job.config else {}
return {
"job_id": job.job_id,
"model_name": job.model_name,
"model_version": job.model_version,
"status": job.status.value if job.status else None,
"training_data_start": job.training_data_start.isoformat() if job.training_data_start else None,
"training_data_end": job.training_data_end.isoformat() if job.training_data_end else None,
"total_steps": job.total_steps,
"current_step": job.current_step,
"batch_size": job.batch_size,
"learning_rate": job.learning_rate,
"loss": job.loss,
"accuracy": job.accuracy,
"checkpoint_path": job.checkpoint_path,
"config": config,
"error_message": job.error_message,
"created_at": job.created_at.isoformat() if job.created_at else None,
"started_at": job.started_at.isoformat() if job.started_at else None,
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
"updated_at": job.updated_at.isoformat() if job.updated_at else None
}
def _step_to_dict(self, step: TrainingStep) -> Dict[str, Any]:
"""Convert step model to dictionary."""
metrics = json.loads(step.metrics) if step.metrics else {}
return {
"id": step.id,
"job_id": step.job_id,
"step_number": step.step_number,
"loss": step.loss,
"accuracy": step.accuracy,
"learning_rate": step.learning_rate,
"metrics": metrics,
"timestamp": step.timestamp.isoformat() if step.timestamp else None
}