File size: 3,679 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from typing import Optional, List, Type, TypeVar
import aiomysql
import logging

from .base_repository import BaseRepository

T = TypeVar("T")

logger = logging.getLogger(__name__)


class MySQLRepository(BaseRepository[T]):
    """Generic CRUD repository for a single MySQL table, using an aiomysql pool.

    The entity class must define ``__tablename__`` (the table to query),
    instances must provide ``to_dict()``, and the class must provide
    ``from_dict(row)`` to build an entity from a result row. Rows are matched
    by an ``id`` column.

    ``self.pool`` is not assigned in this class; it is presumably supplied by
    ``BaseRepository`` (TODO confirm) and must expose ``acquire()`` as an async
    context manager yielding an aiomysql connection.

    Column names that reach SQL text (``to_dict()`` keys, ``filters`` keys) are
    validated and backtick-quoted, because identifiers cannot be passed as
    ``%s`` parameters. The table name comes from the entity class and is
    trusted as-is.
    """

    def __init__(self, entity_class: Type[T]):
        super().__init__()
        self.entity_class = entity_class
        # Table name is taken from the entity class's __tablename__ attribute.
        self.table_name = getattr(entity_class, "__tablename__", None)
        if not self.table_name:
            raise ValueError(
                f"Entity class {entity_class.__name__} must define __tablename__"
            )

    @staticmethod
    def _quote_column(name: str) -> str:
        """Return *name* backtick-quoted for use as a MySQL column identifier.

        Only names that are valid Python identifiers are accepted. That rules
        out backticks, spaces, quotes, and punctuation that could break out of
        the quoted identifier and inject SQL.

        Raises:
            ValueError: if *name* is not a string or not a plain identifier.
        """
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError(f"Invalid column name: {name!r}")
        return f"`{name}`"

    async def get_by_id(self, id: int) -> Optional[T]:
        """Return the entity whose ``id`` equals *id*, or ``None`` if no row matches.

        Any database error is logged (with traceback) and re-raised unchanged.
        """
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute(
                        f"SELECT * FROM {self.table_name} WHERE id = %s", (id,)
                    )
                    result = await cursor.fetchone()
                    return self.entity_class.from_dict(result) if result else None
        except Exception as e:
            # logger.exception records the traceback; lazy %-args defer formatting.
            logger.exception("Database error in get_by_id: %s", e)
            raise  # Re-raise the original exception for unified error handling upstream.

    async def create(self, entity: T) -> T:
        """Insert *entity* and return a new instance built from the stored row.

        If ``to_dict()`` has no ``id`` (or its ``id`` is ``None``), the column is
        omitted so the database can generate it. The generated key is then read
        from ``cursor.lastrowid``. The inserted row is re-selected so that
        database-side defaults appear in the returned entity.

        Raises:
            ValueError: if a ``to_dict()`` key is not a valid column name.
        """
        data = entity.to_dict()
        if data.get("id") is None:
            data.pop("id", None)

        columns = ", ".join(self._quote_column(key) for key in data)
        placeholders = ", ".join(["%s"] * len(data))
        # With no columns this renders "INSERT INTO t () VALUES ()", which
        # MySQL accepts as inserting a row of column defaults.
        query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"

        async with self.pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, tuple(data.values()))
                new_id = data["id"] if "id" in data else cursor.lastrowid
                await conn.commit()
                await cursor.execute(
                    f"SELECT * FROM {self.table_name} WHERE id = %s", (new_id,)
                )
                row = await cursor.fetchone()

        if row is None:
            # Defensive fallback: build the result from what was written.
            row = {**data, "id": new_id}
        return self.entity_class.from_dict(row)

    async def update(self, entity: T) -> T:
        """Write every field of *entity* except ``id`` and ``create_time``.

        The row is matched by the entity's ``id``. The same *entity* object is
        returned. This method does not check whether any row actually matched.

        Raises:
            KeyError: if ``to_dict()`` has no ``id`` key.
            ValueError: if a ``to_dict()`` key is not a valid column name.
        """
        data = entity.to_dict()
        # id identifies the row; it is used in WHERE, not in SET.
        entity_id = data.pop("id")
        # create_time is deliberately never overwritten by an update.
        data.pop("create_time", None)
        if not data:
            # Nothing to SET; an empty SET clause would be a SQL syntax error.
            return entity

        set_clause = ", ".join(f"{self._quote_column(key)} = %s" for key in data)
        query = f"UPDATE {self.table_name} SET {set_clause} WHERE id = %s"
        values = [*data.values(), entity_id]

        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, values)
                await conn.commit()
        return entity

    async def delete(self, id: int) -> bool:
        """Delete the row whose ``id`` equals *id*.

        Returns:
            True if at least one row was deleted, False otherwise.
        """
        async with self.pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    f"DELETE FROM {self.table_name} WHERE id = %s", (id,)
                )
                await conn.commit()
                return cursor.rowcount > 0

    async def list(
        self, filters: Optional[dict] = None, limit: int = 100, offset: int = 0
    ) -> List[T]:
        """Return up to *limit* entities, skipping *offset*, that match *filters*.

        *filters* maps column names to required values, and all conditions are
        ANDed. A ``None`` value matches SQL NULL via ``IS NULL``, because
        ``col = NULL`` never matches any row. No ORDER BY is applied, so row
        order across pages is not guaranteed.

        Raises:
            ValueError: if a filter key is not a valid column name.
        """
        conditions = []
        values = []
        for key, value in (filters or {}).items():
            column = self._quote_column(key)
            if value is None:
                conditions.append(f"{column} IS NULL")
            else:
                conditions.append(f"{column} = %s")
                values.append(value)

        query = f"SELECT * FROM {self.table_name}"
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " LIMIT %s OFFSET %s"
        values.extend([limit, offset])

        async with self.pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, values)
                results = await cursor.fetchall()
        return [self.entity_class.from_dict(row) for row in results]