Spaces:
Sleeping
Sleeping
File size: 3,679 Bytes
01d5a5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
from typing import Optional, List, Type, TypeVar
import aiomysql
import logging
from .base_repository import BaseRepository
# Type variable for the entity class a repository instance is parameterized with.
T = TypeVar("T")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class MySQLRepository(BaseRepository[T]):
    """Generic aiomysql-backed CRUD repository for a single entity class.

    The entity class must define ``__tablename__`` and a ``from_dict``
    constructor; entities passed to ``create`` / ``update`` must provide
    ``to_dict``. Rows are addressed through an ``id`` column.

    NOTE(review): ``self.pool`` is never assigned in this class; it is
    presumably provided by ``BaseRepository`` -- confirm.
    """

    def __init__(self, entity_class: Type[T]):
        """Bind the repository to *entity_class* and its table.

        Raises:
            ValueError: if *entity_class* has no (or an empty) ``__tablename__``.
        """
        super().__init__()
        self.entity_class = entity_class
        # Get table name defined by the entity class
        self.table_name = getattr(entity_class, "__tablename__", None)
        if not self.table_name:
            raise ValueError(
                f"Entity class {entity_class.__name__} must define __tablename__"
            )

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a backtick-quoted SQL column identifier.

        Column names cannot be sent as query parameters, so they are
        validated before being interpolated into SQL text.
        ``str.isidentifier`` rejects backticks, whitespace and punctuation,
        which makes the quoted result safe to embed.

        Raises:
            ValueError: if *name* is not a string or not a plain identifier.
        """
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError(f"Invalid column name: {name!r}")
        return f"`{name}`"

    async def get_by_id(self, id: int) -> Optional[T]:
        """Return the entity whose ``id`` column equals *id*, or ``None``."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(
                        f"SELECT * FROM {self.table_name} WHERE id = %s", (id,)
                    )
                    row = await cursor.fetchone()
                    return self.entity_class.from_dict(row) if row else None
        except Exception as e:
            logger.error("Database error in get_by_id: %s", e)
            raise  # Let the unified error handling deal with the original exception

    async def create(self, entity: T) -> T:
        """Insert *entity* and return a new instance built from the stored row.

        Fields whose value is ``None`` are left out of the INSERT so that
        AUTO_INCREMENT and column defaults apply. The row is read back before
        committing so database-generated values are reflected in the result.

        Raises:
            ValueError: if a key of ``entity.to_dict()`` is not a valid
                column identifier.
        """
        data = entity.to_dict()
        explicit_id = data.get("id")
        insert_data = {k: v for k, v in data.items() if v is not None}
        columns = ", ".join(self._quote_identifier(k) for k in insert_data)
        placeholders = ", ".join(["%s"] * len(insert_data))
        insert_sql = (
            f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
        )
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(insert_sql, tuple(insert_data.values()))
                    # Prefer a caller-supplied id; otherwise use the generated one.
                    new_id = (
                        explicit_id if explicit_id is not None else cursor.lastrowid
                    )
                    await cursor.execute(
                        f"SELECT * FROM {self.table_name} WHERE id = %s",
                        (new_id,),
                    )
                    row = await cursor.fetchone()
                    await conn.commit()
        except Exception as e:
            logger.error("Database error in create: %s", e)
            raise
        if row:
            return self.entity_class.from_dict(row)
        # Fallback: the row could not be read back; return the submitted data.
        return self.entity_class.from_dict({**data, "id": new_id})

    async def update(self, entity: T) -> T:
        """Write all fields of *entity* except ``id`` and ``create_time``.

        The row is selected by the entity's ``id``. When there is nothing
        left to write, no query is issued.

        Returns:
            The same *entity* object that was passed in.

        Raises:
            KeyError: if ``entity.to_dict()`` has no ``"id"`` key.
            ValueError: if a key is not a valid column identifier.
        """
        data = entity.to_dict()
        # Remove id from update data
        entity_id = data.pop("id")
        # Don't update create_time
        data.pop("create_time", None)
        if not data:
            # An empty SET clause would be a SQL syntax error.
            return entity
        set_clause = ", ".join(
            f"{self._quote_identifier(k)} = %s" for k in data
        )
        params = [*data.values(), entity_id]
        query = f"UPDATE {self.table_name} SET {set_clause} WHERE id = %s"
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(query, params)
                    await conn.commit()
        except Exception as e:
            logger.error("Database error in update: %s", e)
            raise
        return entity

    async def delete(self, id: int) -> bool:
        """Delete the row with the given *id*.

        Returns:
            ``True`` if at least one row was removed, ``False`` otherwise.
        """
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        f"DELETE FROM {self.table_name} WHERE id = %s", (id,)
                    )
                    # Read rowcount while the cursor is still open.
                    deleted = cursor.rowcount > 0
                    await conn.commit()
        except Exception as e:
            logger.error("Database error in delete: %s", e)
            raise
        return deleted

    async def list(
        self, filters: Optional[dict] = None, limit: int = 100, offset: int = 0
    ) -> List[T]:
        """Return up to *limit* entities, skipping the first *offset* rows.

        Args:
            filters: optional ``{column: value}`` mapping; all conditions are
                combined with ``AND``. A ``None`` value matches rows where
                the column ``IS NULL``.
            limit: maximum number of rows to return.
            offset: number of rows to skip.

        Raises:
            ValueError: if a filter key is not a valid column identifier.

        NOTE(review): the query has no ORDER BY, so the row order (and
        therefore paging) is whatever the database returns.
        """
        conditions = []
        params = []
        for key, value in (filters or {}).items():
            column = self._quote_identifier(key)
            if value is None:
                # "col = NULL" never matches in SQL; IS NULL is the intent.
                conditions.append(f"{column} IS NULL")
            else:
                conditions.append(f"{column} = %s")
                params.append(value)
        query = f"SELECT * FROM {self.table_name}"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " LIMIT %s OFFSET %s"
        params.extend([limit, offset])
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(query, params)
                    rows = await cursor.fetchall()
        except Exception as e:
            logger.error("Database error in list: %s", e)
            raise
        return [self.entity_class.from_dict(row) for row in rows]
|