from typing import Optional, List, Type, TypeVar import aiomysql import logging from .base_repository import BaseRepository T = TypeVar("T") logger = logging.getLogger(__name__) class MySQLRepository(BaseRepository[T]): def __init__(self, entity_class: Type[T]): super().__init__() self.entity_class = entity_class # Get table name defined by the entity class self.table_name = getattr(entity_class, "__tablename__", None) if not self.table_name: raise ValueError( f"Entity class {entity_class.__name__} must define __tablename__" ) async def get_by_id(self, id: int) -> Optional[T]: try: async with self.pool.acquire() as conn: async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute( f"SELECT * FROM {self.table_name} WHERE id = %s", (id,) ) result = await cursor.fetchone() return self.entity_class.from_dict(result) if result else None except Exception as e: logger.error(f"Database error in get_by_id: {str(e)}") raise # Directly throw the original exception, let the unified error handling handle it async def create(self, entity: T) -> T: with self.db.session() as session: session.add(entity) session.flush() # Ensure ID generation session.refresh(entity) # Refresh the object # Convert to dictionary in session context result = entity.to_dict() return self.model.from_dict(result) # Create a new object instance async def update(self, entity: T) -> T: async with self.pool.acquire() as conn: async with conn.cursor() as cursor: data = entity.to_dict() # Remove id from update data entity_id = data.pop("id") # Don't update create_time data.pop("create_time", None) set_clause = ", ".join([f"{k} = %s" for k in data.keys()]) values = list(data.values()) + [entity_id] query = f"UPDATE {self.table_name} SET {set_clause} WHERE id = %s" await cursor.execute(query, values) await conn.commit() return entity async def delete(self, id: int) -> bool: async with self.pool.acquire() as conn: async with conn.cursor() as cursor: await cursor.execute( f"DELETE FROM {self.table_name} WHERE id = %s", (id,) ) await conn.commit() return cursor.rowcount > 0 async def list( self, filters: dict = None, limit: int = 100, offset: int = 0 ) -> List[T]: async with self.pool.acquire() as conn: async with conn.cursor(aiomysql.DictCursor) as cursor: query = f"SELECT * FROM {self.table_name}" values = [] if filters: where_conditions = [] for key, value in filters.items(): where_conditions.append(f"{key} = %s") values.append(value) if where_conditions: query += " WHERE " + " AND ".join(where_conditions) query += " LIMIT %s OFFSET %s" values.extend([limit, offset]) await cursor.execute(query, values) results = await cursor.fetchall() return [self.entity_class.from_dict(row) for row in results]