|
|
|
|
|
"""
|
|
|
ML Training Service
|
|
|
===================
|
|
|
سرویس آموزش مدلهای یادگیری ماشین با قابلیت پیگیری پیشرفت و ذخیره checkpoint
|
|
|
"""
|
|
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
|
from datetime import datetime
|
|
|
from sqlalchemy.orm import Session
|
|
|
from sqlalchemy import and_, desc
|
|
|
import uuid
|
|
|
import logging
|
|
|
import json
|
|
|
|
|
|
from database.models import (
|
|
|
Base, MLTrainingJob, TrainingStep, TrainingStatus
|
|
|
)
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class MLTrainingService:
|
|
|
"""سرویس اصلی آموزش مدلهای ML"""
|
|
|
|
|
|
def __init__(self, db_session: Session):
|
|
|
"""
|
|
|
Initialize the ML training service.
|
|
|
|
|
|
Args:
|
|
|
db_session: SQLAlchemy database session
|
|
|
"""
|
|
|
self.db = db_session
|
|
|
|
|
|
def start_training(
|
|
|
self,
|
|
|
model_name: str,
|
|
|
training_data_start: datetime,
|
|
|
training_data_end: datetime,
|
|
|
batch_size: int = 32,
|
|
|
learning_rate: Optional[float] = None,
|
|
|
config: Optional[Dict[str, Any]] = None
|
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Start training a model.
|
|
|
|
|
|
Args:
|
|
|
model_name: Name of the model to train
|
|
|
training_data_start: Start date for training data
|
|
|
training_data_end: End date for training data
|
|
|
batch_size: Training batch size
|
|
|
learning_rate: Learning rate (optional)
|
|
|
config: Additional training configuration
|
|
|
|
|
|
Returns:
|
|
|
Dict containing training job details
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
job_id = f"TR-{uuid.uuid4().hex[:12].upper()}"
|
|
|
|
|
|
|
|
|
job = MLTrainingJob(
|
|
|
job_id=job_id,
|
|
|
model_name=model_name,
|
|
|
model_version="1.0.0",
|
|
|
status=TrainingStatus.PENDING,
|
|
|
training_data_start=training_data_start,
|
|
|
training_data_end=training_data_end,
|
|
|
batch_size=batch_size,
|
|
|
learning_rate=learning_rate or 0.001,
|
|
|
config=json.dumps(config) if config else None
|
|
|
)
|
|
|
|
|
|
self.db.add(job)
|
|
|
self.db.commit()
|
|
|
self.db.refresh(job)
|
|
|
|
|
|
logger.info(f"Created training job {job_id} for model {model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
return self._job_to_dict(job)
|
|
|
|
|
|
except Exception as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"Error starting training: {e}", exc_info=True)
|
|
|
raise
|
|
|
|
|
|
def execute_training_step(
|
|
|
self,
|
|
|
job_id: str,
|
|
|
step_number: int,
|
|
|
loss: Optional[float] = None,
|
|
|
accuracy: Optional[float] = None,
|
|
|
learning_rate: Optional[float] = None,
|
|
|
metrics: Optional[Dict[str, Any]] = None
|
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Execute a single training step.
|
|
|
|
|
|
Args:
|
|
|
job_id: Training job ID
|
|
|
step_number: Step number
|
|
|
loss: Training loss
|
|
|
accuracy: Training accuracy
|
|
|
learning_rate: Current learning rate
|
|
|
metrics: Additional metrics
|
|
|
|
|
|
Returns:
|
|
|
Dict containing step details
|
|
|
"""
|
|
|
try:
|
|
|
|
|
|
job = self.db.query(MLTrainingJob).filter(
|
|
|
MLTrainingJob.job_id == job_id
|
|
|
).first()
|
|
|
|
|
|
if not job:
|
|
|
raise ValueError(f"Training job {job_id} not found")
|
|
|
|
|
|
if job.status != TrainingStatus.RUNNING:
|
|
|
raise ValueError(f"Training job {job_id} is not in RUNNING status")
|
|
|
|
|
|
|
|
|
step = TrainingStep(
|
|
|
job_id=job_id,
|
|
|
step_number=step_number,
|
|
|
loss=loss,
|
|
|
accuracy=accuracy,
|
|
|
learning_rate=learning_rate,
|
|
|
metrics=json.dumps(metrics) if metrics else None
|
|
|
)
|
|
|
|
|
|
self.db.add(step)
|
|
|
|
|
|
|
|
|
job.current_step = step_number
|
|
|
if loss is not None:
|
|
|
job.loss = loss
|
|
|
if accuracy is not None:
|
|
|
job.accuracy = accuracy
|
|
|
if learning_rate is not None:
|
|
|
job.learning_rate = learning_rate
|
|
|
|
|
|
self.db.commit()
|
|
|
self.db.refresh(step)
|
|
|
|
|
|
logger.info(f"Training step {step_number} executed for job {job_id}")
|
|
|
|
|
|
return self._step_to_dict(step)
|
|
|
|
|
|
except Exception as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"Error executing training step: {e}", exc_info=True)
|
|
|
raise
|
|
|
|
|
|
def get_training_status(self, job_id: str) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Get the current training status.
|
|
|
|
|
|
Args:
|
|
|
job_id: Training job ID
|
|
|
|
|
|
Returns:
|
|
|
Dict containing training status
|
|
|
"""
|
|
|
try:
|
|
|
job = self.db.query(MLTrainingJob).filter(
|
|
|
MLTrainingJob.job_id == job_id
|
|
|
).first()
|
|
|
|
|
|
if not job:
|
|
|
raise ValueError(f"Training job {job_id} not found")
|
|
|
|
|
|
return self._job_to_dict(job)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error getting training status: {e}", exc_info=True)
|
|
|
raise
|
|
|
|
|
|
def get_training_history(
|
|
|
self,
|
|
|
model_name: Optional[str] = None,
|
|
|
limit: int = 100
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
Get training history.
|
|
|
|
|
|
Args:
|
|
|
model_name: Filter by model name (optional)
|
|
|
limit: Maximum number of jobs to return
|
|
|
|
|
|
Returns:
|
|
|
List of training job dictionaries
|
|
|
"""
|
|
|
try:
|
|
|
query = self.db.query(MLTrainingJob)
|
|
|
|
|
|
if model_name:
|
|
|
query = query.filter(MLTrainingJob.model_name == model_name)
|
|
|
|
|
|
jobs = query.order_by(desc(MLTrainingJob.created_at)).limit(limit).all()
|
|
|
|
|
|
return [self._job_to_dict(job) for job in jobs]
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error retrieving training history: {e}", exc_info=True)
|
|
|
raise
|
|
|
|
|
|
def update_training_status(
|
|
|
self,
|
|
|
job_id: str,
|
|
|
status: str,
|
|
|
checkpoint_path: Optional[str] = None,
|
|
|
error_message: Optional[str] = None
|
|
|
) -> Dict[str, Any]:
|
|
|
"""
|
|
|
Update training job status.
|
|
|
|
|
|
Args:
|
|
|
job_id: Training job ID
|
|
|
status: New status
|
|
|
checkpoint_path: Path to checkpoint (optional)
|
|
|
error_message: Error message if failed (optional)
|
|
|
|
|
|
Returns:
|
|
|
Dict containing updated job details
|
|
|
"""
|
|
|
try:
|
|
|
job = self.db.query(MLTrainingJob).filter(
|
|
|
MLTrainingJob.job_id == job_id
|
|
|
).first()
|
|
|
|
|
|
if not job:
|
|
|
raise ValueError(f"Training job {job_id} not found")
|
|
|
|
|
|
job.status = TrainingStatus[status.upper()]
|
|
|
|
|
|
if status.upper() == "RUNNING" and not job.started_at:
|
|
|
job.started_at = datetime.utcnow()
|
|
|
|
|
|
if status.upper() in ["COMPLETED", "FAILED", "CANCELLED"]:
|
|
|
job.completed_at = datetime.utcnow()
|
|
|
|
|
|
if checkpoint_path:
|
|
|
job.checkpoint_path = checkpoint_path
|
|
|
|
|
|
if error_message:
|
|
|
job.error_message = error_message
|
|
|
|
|
|
self.db.commit()
|
|
|
self.db.refresh(job)
|
|
|
|
|
|
return self._job_to_dict(job)
|
|
|
|
|
|
except Exception as e:
|
|
|
self.db.rollback()
|
|
|
logger.error(f"Error updating training status: {e}", exc_info=True)
|
|
|
raise
|
|
|
|
|
|
def _job_to_dict(self, job: MLTrainingJob) -> Dict[str, Any]:
|
|
|
"""Convert job model to dictionary."""
|
|
|
config = json.loads(job.config) if job.config else {}
|
|
|
|
|
|
return {
|
|
|
"job_id": job.job_id,
|
|
|
"model_name": job.model_name,
|
|
|
"model_version": job.model_version,
|
|
|
"status": job.status.value if job.status else None,
|
|
|
"training_data_start": job.training_data_start.isoformat() if job.training_data_start else None,
|
|
|
"training_data_end": job.training_data_end.isoformat() if job.training_data_end else None,
|
|
|
"total_steps": job.total_steps,
|
|
|
"current_step": job.current_step,
|
|
|
"batch_size": job.batch_size,
|
|
|
"learning_rate": job.learning_rate,
|
|
|
"loss": job.loss,
|
|
|
"accuracy": job.accuracy,
|
|
|
"checkpoint_path": job.checkpoint_path,
|
|
|
"config": config,
|
|
|
"error_message": job.error_message,
|
|
|
"created_at": job.created_at.isoformat() if job.created_at else None,
|
|
|
"started_at": job.started_at.isoformat() if job.started_at else None,
|
|
|
"completed_at": job.completed_at.isoformat() if job.completed_at else None,
|
|
|
"updated_at": job.updated_at.isoformat() if job.updated_at else None
|
|
|
}
|
|
|
|
|
|
def _step_to_dict(self, step: TrainingStep) -> Dict[str, Any]:
|
|
|
"""Convert step model to dictionary."""
|
|
|
metrics = json.loads(step.metrics) if step.metrics else {}
|
|
|
|
|
|
return {
|
|
|
"id": step.id,
|
|
|
"job_id": step.job_id,
|
|
|
"step_number": step.step_number,
|
|
|
"loss": step.loss,
|
|
|
"accuracy": step.accuracy,
|
|
|
"learning_rate": step.learning_rate,
|
|
|
"metrics": metrics,
|
|
|
"timestamp": step.timestamp.isoformat() if step.timestamp else None
|
|
|
}
|
|
|
|
|
|
|