fahmiaziz98 committed on
Commit
d57816a
·
1 Parent(s): 58daf34

validate model type

Browse files
Files changed (1) hide show
  1. src/api/routers/embedding.py +77 -74
src/api/routers/embedding.py CHANGED
@@ -32,6 +32,30 @@ from src.config.settings import get_settings
32
  router = APIRouter(tags=["embeddings"])
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  @router.post(
36
  "/embeddings",
37
  response_model=DenseEmbedResponse,
@@ -46,13 +70,10 @@ async def create_embeddings_document(
46
  """
47
  Generate embeddings for multiple texts.
48
 
49
- Args:
50
- request: BatchEmbedRequest with input, model, and optional parameters
51
- manager: Model manager dependency
52
- settings: Application settings
53
 
54
- Returns:
55
- DenseEmbedResponse
56
  Raises:
57
  HTTPException: On validation or generation errors
58
  """
@@ -66,43 +87,35 @@ async def create_embeddings_document(
66
  kwargs = extract_embedding_kwargs(request)
67
 
68
  model = manager.get_model(request.model)
69
- config = manager.model_configs[request.model]
 
 
70
 
71
  start_time = time.time()
72
 
73
- if config.type == "embeddings":
74
- embeddings = model.embed(
75
- input=request.input, **kwargs
76
- )
77
- processing_time = time.time() - start_time
78
-
79
- data = []
80
- for idx, embedding in enumerate(embeddings):
81
- data.append(
82
- EmbeddingObject(
83
- object="embedding",
84
- embedding=embedding,
85
- index=idx,
86
- )
87
- )
88
-
89
- # Calculate token usage
90
- token_usage = TokenUsage(
91
- prompt_tokens=count_tokens_batch(request.input),
92
- total_tokens=count_tokens_batch(request.input),
93
- )
94
 
95
- response = DenseEmbedResponse(
96
- object="list",
97
- data=data,
98
- model=request.model,
99
- usage=token_usage,
100
- )
101
- else:
102
- raise HTTPException(
103
- status_code=status.HTTP_400_BAD_REQUEST,
104
- detail=f"Model '{request.model}' is not a dense model. Type: {config.type}",
105
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  logger.info(
108
  f"Generated {len(request.input)} embeddings "
@@ -138,12 +151,9 @@ async def create_sparse_embedding(
138
  """
139
  Generate a single/batch sparse embedding.
140
 
141
- Args:
142
- request: EmbedRequest with input, model, and optional parameters
143
- manager: Model manager dependency
144
-
145
- Returns:
146
- SparseEmbedResponse
147
 
148
  Raises:
149
  HTTPException: On validation or generation errors
@@ -153,41 +163,33 @@ async def create_sparse_embedding(
153
  kwargs = extract_embedding_kwargs(request)
154
 
155
  model = manager.get_model(request.model)
156
- config = manager.model_configs[request.model]
 
 
157
 
158
  start_time = time.time()
159
 
160
- if config.type == "sparse-embeddings":
161
- sparse_results = model.embed(
162
- input=request.input, **kwargs
163
- )
164
- processing_time = time.time() - start_time
165
-
166
- sparse_embeddings = []
167
- for idx, sparse_result in enumerate(sparse_results):
168
- sparse_embeddings.append(
169
- SparseEmbedding(
170
- text=request.input[idx],
171
- indices=sparse_result["indices"],
172
- values=sparse_result["values"],
173
- )
174
- )
175
-
176
- response = SparseEmbedResponse(
177
- embeddings=sparse_embeddings,
178
- count=len(sparse_embeddings),
179
- model=request.model
180
- )
181
-
182
- else:
183
- raise HTTPException(
184
- status_code=status.HTTP_400_BAD_REQUEST,
185
- detail=f"Model '{request.model}' is not a sparse model. Type: {config.type}",
186
  )
 
 
 
 
 
 
 
 
187
 
188
  logger.info(
189
- f"Generated {len(request.texts)} embeddings "
190
- f"in {processing_time:.3f}s ({len(request.texts) / processing_time:.1f} texts/s)"
191
  )
192
 
193
  return response
@@ -199,8 +201,9 @@ async def create_sparse_embedding(
199
  except EmbeddingGenerationError as e:
200
  raise HTTPException(status_code=e.status_code, detail=e.message)
201
  except Exception as e:
202
- logger.exception("Unexpected error in create_query_embedding")
203
  raise HTTPException(
204
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
205
  detail=f"Failed to create query embedding: {str(e)}",
206
  )
 
 
32
  router = APIRouter(tags=["embeddings"])
33
 
34
 
35
+ def _ensure_model_type(
36
+ config, expected_type: str, model_id: str
37
+ ) -> None:
38
+ """
39
+ Validate that the model configuration matches the expected type.
40
+
41
+ Raises:
42
+ HTTPException: If the model is missing or the type does not match.
43
+ """
44
+ if config is None:
45
+ raise HTTPException(
46
+ status_code=status.HTTP_404_NOT_FOUND,
47
+ detail=f"Model '{model_id}' not found.",
48
+ )
49
+ if config.type != expected_type:
50
+ raise HTTPException(
51
+ status_code=status.HTTP_400_BAD_REQUEST,
52
+ detail=(
53
+ f"Model '{model_id}' is not a {expected_type.replace('-', ' ')} "
54
+ f"model. Detected type: {config.type}"
55
+ ),
56
+ )
57
+
58
+
59
  @router.post(
60
  "/embeddings",
61
  response_model=DenseEmbedResponse,
 
70
  """
71
  Generate embeddings for multiple texts.
72
 
73
+ The endpoint validates the request, checks that the requested
74
+ model is a dense embedding model, and returns a
75
+ :class:`DenseEmbedResponse`.
 
76
 
 
 
77
  Raises:
78
  HTTPException: On validation or generation errors
79
  """
 
87
  kwargs = extract_embedding_kwargs(request)
88
 
89
  model = manager.get_model(request.model)
90
+ config = manager.model_configs.get(request.model)
91
+
92
+ _ensure_model_type(config, "embeddings", request.model)
93
 
94
  start_time = time.time()
95
 
96
+ embeddings = model.embed(input=request.input, **kwargs)
97
+ processing_time = time.time() - start_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ data = [
100
+ EmbeddingObject(
101
+ object="embedding",
102
+ embedding=embedding,
103
+ index=idx,
 
 
 
 
 
104
  )
105
+ for idx, embedding in enumerate(embeddings)
106
+ ]
107
+
108
+ token_usage = TokenUsage(
109
+ prompt_tokens=count_tokens_batch(request.input),
110
+ total_tokens=count_tokens_batch(request.input),
111
+ )
112
+
113
+ response = DenseEmbedResponse(
114
+ object="list",
115
+ data=data,
116
+ model=request.model,
117
+ usage=token_usage,
118
+ )
119
 
120
  logger.info(
121
  f"Generated {len(request.input)} embeddings "
 
151
  """
152
  Generate a single/batch sparse embedding.
153
 
154
+ The endpoint validates the request, checks that the requested
155
+ model is a sparse embedding model, and returns a
156
+ :class:`SparseEmbedResponse`.
 
 
 
157
 
158
  Raises:
159
  HTTPException: On validation or generation errors
 
163
  kwargs = extract_embedding_kwargs(request)
164
 
165
  model = manager.get_model(request.model)
166
+ config = manager.model_configs.get(request.model)
167
+
168
+ _ensure_model_type(config, "sparse-embeddings", request.model)
169
 
170
  start_time = time.time()
171
 
172
+ sparse_results = model.embed(input=request.input, **kwargs)
173
+ processing_time = time.time() - start_time
174
+
175
+ sparse_embeddings = [
176
+ SparseEmbedding(
177
+ text=request.input[idx],
178
+ indices=sparse_result["indices"],
179
+ values=sparse_result["values"],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  )
181
+ for idx, sparse_result in enumerate(sparse_results)
182
+ ]
183
+
184
+ response = SparseEmbedResponse(
185
+ embeddings=sparse_embeddings,
186
+ count=len(sparse_embeddings),
187
+ model=request.model,
188
+ )
189
 
190
  logger.info(
191
+ f"Generated {len(request.input)} embeddings "
192
+ f"in {processing_time:.3f}s ({len(request.input) / processing_time:.1f} texts/s)"
193
  )
194
 
195
  return response
 
201
  except EmbeddingGenerationError as e:
202
  raise HTTPException(status_code=e.status_code, detail=e.message)
203
  except Exception as e:
204
+ logger.exception("Unexpected error in create_sparse_embedding")
205
  raise HTTPException(
206
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
207
  detail=f"Failed to create query embedding: {str(e)}",
208
  )
209
+