Spaces:
Running
Running
debug model addition to registry
Browse files- app/routers/models.py +12 -11
- tests/test_openai_compat.py +1 -1
app/routers/models.py
CHANGED
|
@@ -9,22 +9,23 @@ from ..core.model_registry import ModelSpec, get_model_spec, list_models
|
|
| 9 |
router = APIRouter(prefix="/v1", tags=["models"])
|
| 10 |
|
| 11 |
|
| 12 |
-
def _serialize_model(spec: ModelSpec) -> dict:
|
| 13 |
payload = {
|
| 14 |
"id": spec.name,
|
| 15 |
"object": "model",
|
| 16 |
"owned_by": "owner",
|
| 17 |
"permission": [],
|
| 18 |
}
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
| 28 |
return payload
|
| 29 |
|
| 30 |
|
|
@@ -40,4 +41,4 @@ def retrieve_model(model_id: str) -> dict:
|
|
| 40 |
spec = get_model_spec(model_id)
|
| 41 |
except KeyError:
|
| 42 |
raise model_not_found(model_id)
|
| 43 |
-
return _serialize_model(spec)
|
|
|
|
| 9 |
router = APIRouter(prefix="/v1", tags=["models"])
|
| 10 |
|
| 11 |
|
| 12 |
+
def _serialize_model(spec: ModelSpec, include_metadata: bool = False) -> dict:
|
| 13 |
payload = {
|
| 14 |
"id": spec.name,
|
| 15 |
"object": "model",
|
| 16 |
"owned_by": "owner",
|
| 17 |
"permission": [],
|
| 18 |
}
|
| 19 |
+
if include_metadata:
|
| 20 |
+
metadata = spec.metadata.to_dict() if spec.metadata else {"description": "No additional details provided."}
|
| 21 |
+
metadata.setdefault("huggingface_repo", spec.hf_repo)
|
| 22 |
+
if spec.max_context_tokens is not None:
|
| 23 |
+
metadata.setdefault("max_context_tokens", spec.max_context_tokens)
|
| 24 |
+
if spec.dtype:
|
| 25 |
+
metadata.setdefault("dtype", spec.dtype)
|
| 26 |
+
if spec.device:
|
| 27 |
+
metadata.setdefault("default_device", spec.device)
|
| 28 |
+
payload["metadata"] = metadata
|
| 29 |
return payload
|
| 30 |
|
| 31 |
|
|
|
|
| 41 |
spec = get_model_spec(model_id)
|
| 42 |
except KeyError:
|
| 43 |
raise model_not_found(model_id)
|
| 44 |
+
return _serialize_model(spec, include_metadata=True)
|
tests/test_openai_compat.py
CHANGED
|
@@ -260,6 +260,6 @@ def test_model_detail_serialization(monkeypatch: pytest.MonkeyPatch) -> None:
|
|
| 260 |
monkeypatch.setattr(models, "get_model_spec", lambda _: spec)
|
| 261 |
|
| 262 |
listing = models.list_available_models()
|
| 263 |
-
assert listing["data"][0]
|
| 264 |
detail = models.retrieve_model("example")
|
| 265 |
assert detail["metadata"]["huggingface_repo"] == "example/repo"
|
|
|
|
| 260 |
monkeypatch.setattr(models, "get_model_spec", lambda _: spec)
|
| 261 |
|
| 262 |
listing = models.list_available_models()
|
| 263 |
+
assert "metadata" not in listing["data"][0]
|
| 264 |
detail = models.retrieve_model("example")
|
| 265 |
assert detail["metadata"]["huggingface_repo"] == "example/repo"
|