feat(ai-runtime): complete ai runtime policy refactor (ADR-035)
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/.env.template
|
||||
# Change Log:
|
||||
# - 2026-06-11: สร้างไฟล์ env template สำหรับ Desk-5439 (US5)
|
||||
|
||||
# ─── VRAM, Residency & Timeout Configurations ───
|
||||
VRAM_HEADROOM_THRESHOLD_MB=3000.0
|
||||
OCR_RESIDENCY_WINDOW_SECONDS=120
|
||||
GPU_TOTAL_VRAM_MB=16384.0
|
||||
GPU_MAIN_MODEL_PRESSURE_THRESHOLD_MB=12000.0
|
||||
RETRIEVAL_TIMEOUT_SECONDS=30.0
|
||||
|
||||
# ─── Queue policy & concurrency ───
|
||||
REALTIME_CONCURRENCY=2
|
||||
+5
-7
@@ -1,12 +1,10 @@
|
||||
FROM scb10x/typhoon2.5-qwen3-4b:latest
|
||||
|
||||
|
||||
|
||||
PARAMETER num\_ctx 8192
|
||||
PARAMETER num\_predict 4096
|
||||
PARAMETER num_ctx 8192
|
||||
PARAMETER num_predict 4096
|
||||
PARAMETER temperature 0.4
|
||||
|
||||
PARAMETER top\_k 40
|
||||
PARAMETER top\_p 0.9
|
||||
PARAMETER repeat\_penalty 1.15
|
||||
PARAMETER top_k 40
|
||||
PARAMETER top_p 0.9
|
||||
PARAMETER repeat_penalty 1.15
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose\Desk-5439\ocr-sidecar\app.py
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/ocr-sidecar/app.py
|
||||
# Typhoon OCR HTTP Sidecar API — รับ POST /ocr แล้วคืนข้อความที่สกัดจาก PDF/Image
|
||||
# ตาม ADR-023A (revised 2026-06-11): ใช้ typhoon_ocr library + np-dms-ocr (Ollama) แทน Tesseract
|
||||
# Change Log:
|
||||
@@ -21,6 +21,7 @@
|
||||
# - 2026-06-05: เพิ่ม Option 2 (aggressive preprocessing: deskew + Otsu threshold + morphology) และ Option 3 (smart post-processing: regex-based hallucination removal) เพื่อลด Tesseract noise/hallucination (T025)
|
||||
# - 2026-06-06: เปลี่ยน keep_alive จาก 300s เป็น 0 เพื่อ unload model ทันทีหลังเสร็จงาน (แก้ปัญหา VRAM ไม่พอเมื่อ typhoon2.5-np-dms load พร้อมกัน)
|
||||
# - 2026-06-11: เปลี่ยน process_with_typhoon_ocr ให้ใช้ prepare_ocr_messages จาก typhoon_ocr library + inject DMS tags; เปลี่ยน endpoint เป็น /v1/chat/completions
|
||||
# - 2026-06-11: US2 & US3 - เพิ่ม keep_alive parameter และ CPU fallback สำหรับ /embed และ /rerank
|
||||
|
||||
import os
|
||||
import logging
|
||||
@@ -30,11 +31,13 @@ import json
|
||||
import tempfile
|
||||
import fitz # PyMuPDF (ใช้สำหรับ page count + fast-path text extraction)
|
||||
import httpx
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
import io
|
||||
from typhoon_ocr import prepare_ocr_messages
|
||||
from services.vram_monitor import get_vram_headroom
|
||||
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, Security, status
|
||||
from fastapi.security.api_key import APIKeyHeader
|
||||
@@ -104,6 +107,7 @@ class OcrRequest(BaseModel):
|
||||
pdfPath: str
|
||||
maxPages: Optional[int] = None
|
||||
engine: Optional[str] = None
|
||||
keep_alive: Optional[int] = None
|
||||
|
||||
class OcrResponse(BaseModel):
|
||||
text: str
|
||||
@@ -211,7 +215,7 @@ def process_with_typhoon_ocr(pdf_path: str, page_num: int = 1, options_override:
|
||||
"repetition_penalty": options_override.get("repeat_penalty", 1.2),
|
||||
"temperature": options_override.get("temperature", 0.1),
|
||||
"top_p": options_override.get("top_p", 0.6),
|
||||
"keep_alive": 0, # Unload model ทันทีหลังเสร็จงานเพื่อคืน VRAM ให้ np-dms-ai ใช้งานได้
|
||||
"keep_alive": options_override.get("keep_alive", 0), # Unload model ทันทีหลังเสร็จงานเพื่อคืน VRAM ให้ np-dms-ai ใช้งานได้
|
||||
}
|
||||
# ใช้ Ollama OpenAI-compatible endpoint (/v1/chat/completions)
|
||||
with httpx.Client(timeout=TYPHOON_OCR_TIMEOUT) as client:
|
||||
@@ -249,11 +253,14 @@ def ocr_extract(req: OcrRequest):
|
||||
raise HTTPException(status_code=404, detail=f"ไม่พบไฟล์: {req.pdfPath}")
|
||||
selected_engine = (req.engine or "auto").strip().lower()
|
||||
max_pages = req.maxPages or MAX_PAGES
|
||||
typhoon_options = {}
|
||||
if req.keep_alive is not None:
|
||||
typhoon_options["keep_alive"] = req.keep_alive
|
||||
try:
|
||||
doc = fitz.open(str(pdf_path))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=422, detail=f"เปิดไฟล์ PDF ล้มเหลว: {e}")
|
||||
return _process_pdf_doc(doc, selected_engine, max_pages)
|
||||
return _process_pdf_doc(doc, selected_engine, max_pages, typhoon_options)
|
||||
|
||||
@app.post("/ocr-upload", response_model=OcrResponse, dependencies=[Depends(get_api_key)])
|
||||
def ocr_upload(
|
||||
@@ -263,6 +270,7 @@ def ocr_upload(
|
||||
temperature: Optional[float] = Form(default=None),
|
||||
topP: Optional[float] = Form(default=None),
|
||||
repeatPenalty: Optional[float] = Form(default=None),
|
||||
keep_alive: Optional[int] = Form(default=None),
|
||||
):
|
||||
"""OCR จาก multipart file upload — ไม่ต้องการ shared volume mount"""
|
||||
selected_engine = engine.strip().lower()
|
||||
@@ -275,6 +283,8 @@ def ocr_upload(
|
||||
typhoon_options["top_p"] = topP
|
||||
if repeatPenalty is not None:
|
||||
typhoon_options["repeat_penalty"] = repeatPenalty
|
||||
if keep_alive is not None:
|
||||
typhoon_options["keep_alive"] = keep_alive
|
||||
pdf_bytes = file.file.read()
|
||||
import tempfile
|
||||
tmp_pdf_path: str | None = None
|
||||
@@ -317,6 +327,7 @@ class EmbedRequest(BaseModel):
|
||||
class EmbedResponse(BaseModel):
    # Response body of POST /embed.
    # (Kept as comments rather than a docstring so the generated schema
    # description is unchanged.)

    # Dense embedding vector, one float per dimension.
    dense: list[float]

    # Sparse lexical weights; the /embed handler fills this as
    # {"indices": [...], "values": [...]}.
    sparse: dict

    # Device the handler selected for inference ("cuda" or "cpu");
    # None when not reported.
    device: Optional[str] = None
|
||||
|
||||
class RerankRequest(BaseModel):
|
||||
query: str
|
||||
@@ -325,54 +336,133 @@ class RerankRequest(BaseModel):
|
||||
class RerankResponse(BaseModel):
    # Response body of POST /rerank.
    # (Kept as comments rather than a docstring so the generated schema
    # description is unchanged.)

    # Relevance score per input chunk, in the same order as the request chunks.
    scores: list[float]

    # Chunk indices ordered from highest to lowest score.
    ranked_indices: list[int]

    # Device the handler selected for inference ("cuda" or "cpu");
    # None when not reported.
    device: Optional[str] = None
|
||||
|
||||
@app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(get_api_key)])
async def embed_text(req: EmbedRequest):
    """BGE-M3 embedding generator (dense + sparse) with CPU fallback and a timeout guard."""
    # Returns an EmbedResponse with the dense vector, the sparse weights as
    # {"indices", "values"} and the device that was selected.
    # Raises HTTPException 503 (model not loaded), 504 (inference exceeded
    # RETRIEVAL_TIMEOUT_SECONDS) or 500 (any other inference failure).
    if bge_model is None:
        raise HTTPException(status_code=503, detail="BGE-M3 model not loaded")

    threshold_mb = float(os.getenv("VRAM_HEADROOM_THRESHOLD_MB", "3000.0"))
    timeout_sec = float(os.getenv("RETRIEVAL_TIMEOUT_SECONDS", "30.0"))

    # get_vram_headroom() issues a synchronous HTTP request; run it in a
    # worker thread so this coroutine does not stall the event loop.
    headroom = await asyncio.to_thread(get_vram_headroom)

    # Prefer the GPU, and fall back to CPU when the VRAM query failed or the
    # reported headroom is below the configured threshold.
    device = "cuda"
    reason = "headroom-sufficient"
    if not headroom.query_success:
        device = "cpu"
        reason = "gpu-query-failed"
    elif headroom.available_mb < threshold_mb:
        device = "cpu"
        reason = "gpu-headroom-below-threshold"

    # NOTE(review): the device move below still runs on the event-loop thread
    # and mutates the shared model object; confirm whether concurrent
    # requests can race on it.
    try:
        if device == "cuda":
            import torch

            if torch.cuda.is_available():
                bge_model.model.to("cuda")
            else:
                device = "cpu"
                reason = "cuda-not-available"
                bge_model.model.to("cpu")
        else:
            bge_model.model.to("cpu")
    except Exception as e:
        logger.warning("Failed to move BGE-M3 model to %s: %s", device, e)
        device = "cpu"
        reason = f"device-move-failed: {e}"
        try:
            bge_model.model.to("cpu")
        except Exception as cpu_err:
            # Best effort: inference is still attempted, but the failure is
            # no longer hidden.
            logger.warning("CPU fallback move for BGE-M3 model also failed: %s", cpu_err)

    logger.info("Embedding on device: %s (reason: %s)", device, reason)

    def run_inference():
        """Encode the request text and flatten the output into plain Python lists."""
        output = bge_model.encode([req.text], return_dense=True, return_sparse=True)
        dense_vector = [float(x) for x in output['dense_vecs'][0]]
        lexical_dict = output['lexical_weights'][0]

        # Keys and values are emitted in the same (dict) order.
        indices = [int(token_id) for token_id in lexical_dict]
        values = [float(weight) for weight in lexical_dict.values()]
        return dense_vector, indices, values

    try:
        # wait_for only stops waiting on timeout; the worker thread itself
        # cannot be cancelled and runs to completion in the background.
        dense_vector, indices, values = await asyncio.wait_for(
            asyncio.to_thread(run_inference),
            timeout=timeout_sec,
        )
        return EmbedResponse(
            dense=dense_vector,
            sparse={"indices": indices, "values": values},
            device=device,
        )
    except asyncio.TimeoutError:
        logger.error("Embedding generation timed out after %ss on device %s", timeout_sec, device)
        raise HTTPException(status_code=504, detail="Embedding generation timed out")
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("Embedding generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}") from e
|
||||
|
||||
@app.post("/rerank", response_model=RerankResponse, dependencies=[Depends(get_api_key)])
async def rerank_chunks(req: RerankRequest):
    """BGE-Reranker-Large chunk re-ranker with CPU fallback and a timeout guard."""
    # Returns a RerankResponse with one score per chunk (request order), the
    # chunk indices sorted by descending score and the selected device.
    # Raises HTTPException 503 (model not loaded), 504 (scoring exceeded
    # RETRIEVAL_TIMEOUT_SECONDS) or 500 (any other scoring failure).
    if reranker is None:
        raise HTTPException(status_code=503, detail="Reranker model not loaded")
    if not req.chunks:
        # Nothing to score: answer immediately without touching the model.
        return RerankResponse(scores=[], ranked_indices=[], device="cpu")

    threshold_mb = float(os.getenv("VRAM_HEADROOM_THRESHOLD_MB", "3000.0"))
    timeout_sec = float(os.getenv("RETRIEVAL_TIMEOUT_SECONDS", "30.0"))

    # get_vram_headroom() issues a synchronous HTTP request; run it in a
    # worker thread so this coroutine does not stall the event loop.
    headroom = await asyncio.to_thread(get_vram_headroom)

    # Prefer the GPU, and fall back to CPU when the VRAM query failed or the
    # reported headroom is below the configured threshold.
    device = "cuda"
    reason = "headroom-sufficient"
    if not headroom.query_success:
        device = "cpu"
        reason = "gpu-query-failed"
    elif headroom.available_mb < threshold_mb:
        device = "cpu"
        reason = "gpu-headroom-below-threshold"

    # NOTE(review): the device move below still runs on the event-loop thread
    # and mutates the shared model object; confirm whether concurrent
    # requests can race on it.
    try:
        if device == "cuda":
            import torch

            if torch.cuda.is_available():
                reranker.model.to("cuda")
            else:
                device = "cpu"
                reason = "cuda-not-available"
                reranker.model.to("cpu")
        else:
            reranker.model.to("cpu")
    except Exception as e:
        logger.warning("Failed to move Reranker model to %s: %s", device, e)
        device = "cpu"
        reason = f"device-move-failed: {e}"
        try:
            reranker.model.to("cpu")
        except Exception as cpu_err:
            # Best effort: scoring is still attempted, but the failure is no
            # longer hidden.
            logger.warning("CPU fallback move for Reranker model also failed: %s", cpu_err)

    logger.info("Reranking on device: %s (reason: %s)", device, reason)

    def run_rerank():
        """Score every (query, chunk) pair and rank chunk indices by score."""
        pairs = [[req.query, chunk] for chunk in req.chunks]
        raw_scores = reranker.compute_score(pairs)
        try:
            scores = [float(s) for s in raw_scores]
        except TypeError:
            # A single pair may come back as one bare numeric scalar (not
            # necessarily a built-in float) instead of a sequence.
            scores = [float(raw_scores)]

        # Stable descending sort: equal scores keep their request order.
        ranked_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        return scores, ranked_indices

    try:
        # wait_for only stops waiting on timeout; the worker thread itself
        # cannot be cancelled and runs to completion in the background.
        scores, ranked_indices = await asyncio.wait_for(
            asyncio.to_thread(run_rerank),
            timeout=timeout_sec,
        )
        return RerankResponse(
            scores=scores,
            ranked_indices=ranked_indices,
            device=device,
        )
    except asyncio.TimeoutError:
        logger.error("Reranking timed out after %ss on device %s", timeout_sec, device)
        raise HTTPException(status_code=504, detail="Reranking timed out")
    except Exception as e:
        # logger.exception keeps the traceback for diagnosis.
        logger.exception("Reranking failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Reranking failed: {str(e)}") from e
|
||||
|
||||
+7
@@ -13,6 +13,7 @@
|
||||
# - 2026-06-04: ADR-034 — เปลี่ยน TYPHOON_OCR_MODEL เป็น typhoon-np-dms-ocr:latest; OLLAMA_API_URL ชี้ตรงไป Ollama (ไม่ผ่าน metrics proxy) เพื่อป้องกัน empty response
|
||||
# - 2026-06-02: เพิ่ม ollama-metrics (NorskHelsenett) — Prometheus sidecar สำหรับ Ollama metrics
|
||||
# expose /metrics บน port 9924; Prometheus (ASUSTOR) scrape จาก 192.168.10.100:9924
|
||||
# - 2026-06-11: US2 & US3 - เพิ่ม VRAM headroom, residency window, pressure threshold, retrieval timeout env variables
|
||||
#
|
||||
# วิธีรัน:
|
||||
# docker compose up -d --build
|
||||
@@ -45,6 +46,12 @@ services:
|
||||
TYPHOON_OCR_MODEL: "typhoon-np-dms-ocr:latest"
|
||||
# Timeout 360 วินาที/หน้า — รองรับ cold-start โหลด model (~70s) + inference (10GB model, CPU offload)
|
||||
TYPHOON_OCR_TIMEOUT: "360"
|
||||
# ─── VRAM, Residency & Timeout Configurations (Feature-235) ──────────────
|
||||
VRAM_HEADROOM_THRESHOLD_MB: "3000.0"
|
||||
OCR_RESIDENCY_WINDOW_SECONDS: "120"
|
||||
GPU_TOTAL_VRAM_MB: "16384.0"
|
||||
GPU_MAIN_MODEL_PRESSURE_THRESHOLD_MB: "12000.0"
|
||||
RETRIEVAL_TIMEOUT_SECONDS: "30.0"
|
||||
logging:
|
||||
driver: "json-file"
|
||||
options:
|
||||
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/ocr-sidecar/services/residency_policy.py
|
||||
# Change Log:
|
||||
# - 2026-06-11: Initial creation of residency_policy.py for calculating OCR keep_alive value dynamically
|
||||
|
||||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from services.vram_monitor import get_vram_headroom
|
||||
|
||||
logger = logging.getLogger("ocr-sidecar.residency-policy")
|
||||
|
||||
@dataclass
class OcrResidencyDecision:
    """Result of the OCR residency policy: how long the OCR model may stay loaded and why."""

    # keep_alive value in seconds; 0 means "do not keep the model resident".
    keep_alive_seconds: int
    # Available VRAM in MB seen when deciding; -1.0 when it was not measured
    # (profile short-circuit) or the query failed.
    vram_headroom_mb: float
    # Short tag naming the rule that produced the decision.
    reason: str


def calculate_ocr_residency(active_profile: "str | None" = None) -> OcrResidencyDecision:
    """
    Compute the keep_alive for Typhoon OCR from VRAM headroom and the main model's active profile.

    Rules, evaluated in order:
      1. profile "deep-analysis" or "large-context"   -> 0 s ("large-context-active")
      2. VRAM query failed                              -> 0 s ("query-failed")
      3. used VRAM above the pressure threshold         -> 0 s ("high-pressure")
      4. available VRAM below the headroom threshold    -> 0 s ("high-pressure")
      5. otherwise                                      -> residency window ("headroom-sufficient")

    Args:
        active_profile: Name of the main model's active profile, or None.

    Returns:
        OcrResidencyDecision with the keep_alive seconds, the observed
        headroom and the reason tag.

    Raises:
        ValueError: if one of the environment settings is not numeric.
    """
    threshold_mb = float(os.getenv("VRAM_HEADROOM_THRESHOLD_MB", "3000.0"))
    # Parse via float so a value such as "120.0" is accepted, in line with
    # the float-style strings used for the sibling settings.
    residency_window = int(float(os.getenv("OCR_RESIDENCY_WINDOW_SECONDS", "120")))
    # Default matches the value shipped in the env template and compose file
    # (12000.0); it was 7000.0 here before.
    # NOTE(review): confirm 12000.0 is the intended fallback.
    pressure_threshold = float(os.getenv("GPU_MAIN_MODEL_PRESSURE_THRESHOLD_MB", "12000.0"))

    # These profiles always evict the OCR model; VRAM is not queried.
    if active_profile in ("deep-analysis", "large-context"):
        return OcrResidencyDecision(0, -1.0, "large-context-active")

    headroom = get_vram_headroom()
    if not headroom.query_success:
        return OcrResidencyDecision(0, -1.0, "query-failed")

    # Both pressure conditions intentionally share the same reason tag.
    if headroom.used_mb > pressure_threshold or headroom.available_mb < threshold_mb:
        return OcrResidencyDecision(0, headroom.available_mb, "high-pressure")

    return OcrResidencyDecision(residency_window, headroom.available_mb, "headroom-sufficient")
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/ocr-sidecar/services/vram_monitor.py
|
||||
# Change Log:
|
||||
# - 2026-06-11: Initial creation of VramMonitor service for Python OCR sidecar to query GPU VRAM headroom from Ollama /api/ps
|
||||
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("ocr-sidecar.vram-monitor")
|
||||
|
||||
@dataclass
class VramHeadroom:
    """Snapshot of GPU VRAM usage used for residency and CPU-fallback decisions."""

    # Total VRAM in MB.
    total_mb: float
    # VRAM in MB currently in use.
    used_mb: float
    # VRAM in MB still free.
    available_mb: float
    # False when the usage query did not succeed and the numbers are a
    # fallback rather than a measurement.
    query_success: bool
|
||||
|
||||
def get_vram_headroom() -> VramHeadroom:
    """
    Query Ollama ``/api/ps`` and compute the remaining VRAM headroom.

    Used VRAM is the sum of ``size_vram`` over the models Ollama reports as
    running; the total comes from ``GPU_TOTAL_VRAM_MB`` (default 16384.0).
    The result feeds the residency and CPU-fallback decisions.

    Returns:
        VramHeadroom with ``query_success=True`` on success. On any failure
        (non-200 status, network error, malformed payload) a fail-safe value
        is returned instead of raising: ``used_mb == total_mb``,
        ``available_mb == 0.0`` and ``query_success=False``.

    Raises:
        ValueError: if ``GPU_TOTAL_VRAM_MB`` is not numeric.
    """
    # Strip a trailing slash so the request URL never contains "//api/ps".
    ollama_url = os.getenv("OLLAMA_API_URL", "http://host.docker.internal:11434").rstrip("/")
    total_vram_mb = float(os.getenv("GPU_TOTAL_VRAM_MB", "16384.0"))

    # Fail-safe answer: report the GPU as full so callers avoid it.
    query_failed = VramHeadroom(total_vram_mb, total_vram_mb, 0.0, False)

    try:
        # Fetch the running-model status from Ollama.
        with httpx.Client(timeout=3.0) as client:
            response = client.get(f"{ollama_url}/api/ps")

        if response.status_code != 200:
            logger.warning("Ollama ps endpoint returned status code: %s", response.status_code)
            return query_failed

        # Treat a missing or null "models" list and a null "size_vram" as
        # zero usage instead of letting a TypeError mark the query as failed.
        models = response.json().get("models") or []
        total_used_bytes = sum((model.get("size_vram") or 0) for model in models)

        used_mb = float(total_used_bytes) / (1024.0 * 1024.0)
        available_mb = max(0.0, total_vram_mb - used_mb)
        return VramHeadroom(total_vram_mb, used_mb, available_mb, True)
    except Exception as e:
        # Deliberately broad: this is a best-effort probe and must never
        # break the request that triggered it.
        logger.warning("Failed to query Ollama VRAM: %s", e)
        return query_failed
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/ocr-sidecar/tests/test_retrieval_fallback.py
|
||||
# Change Log:
|
||||
# - 2026-06-11: Initial integration tests for retrieval fallback using pytest
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
# Environment must be in place before `app` is imported below.
os.environ.update(
    {
        "OCR_SIDECAR_API_KEY": "test-key",
        "VRAM_HEADROOM_THRESHOLD_MB": "3000.0",
        "RETRIEVAL_TIMEOUT_SECONDS": "2.0",
    }
)

from app import app, EmbedRequest, RerankRequest, get_api_key

# Shared HTTP client and auth header used by every test in this module.
API_HEADERS = {"X-API-Key": "test-key"}
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
def mock_bge_model():
    """Replace ``app.bge_model`` with a mock whose ``encode`` returns a tiny fixed embedding."""
    canned_output = {
        "dense_vecs": [[0.1, 0.2]],
        "lexical_weights": [{"101": 0.5}]
    }
    with patch("app.bge_model") as fake_model:
        # `.model` is what the endpoint moves between devices via `.to(...)`.
        fake_model.model = MagicMock()
        fake_model.encode.return_value = canned_output
        yield fake_model
|
||||
|
||||
@pytest.fixture
def mock_reranker():
    """Replace ``app.reranker`` with a mock whose ``compute_score`` returns one fixed score."""
    with patch("app.reranker") as fake_reranker:
        # `.model` is what the endpoint moves between devices via `.to(...)`.
        fake_reranker.model = MagicMock()
        fake_reranker.compute_score.return_value = [0.85]
        yield fake_reranker
|
||||
|
||||
def test_embed_gpu_when_headroom_sufficient(mock_bge_model):
    """/embed runs on CUDA when headroom is above the threshold and CUDA is available."""
    ample_headroom = MagicMock(
        total_mb=16384.0, used_mb=2000.0, available_mb=14384.0, query_success=True
    )
    with patch("app.get_vram_headroom", return_value=ample_headroom):
        with patch("torch.cuda.is_available", return_value=True):
            response = client.post("/embed", json={"text": "hello test"}, headers=API_HEADERS)

    assert response.status_code == 200
    payload = response.json()
    assert payload["device"] == "cuda"
    mock_bge_model.model.to.assert_called_with("cuda")
|
||||
|
||||
def test_embed_cpu_when_headroom_insufficient(mock_bge_model):
    """/embed falls back to CPU when available VRAM is below the 3000 MB threshold."""
    tight_headroom = MagicMock(
        total_mb=16384.0, used_mb=14000.0, available_mb=2384.0, query_success=True
    )
    with patch("app.get_vram_headroom", return_value=tight_headroom):
        response = client.post("/embed", json={"text": "hello test"}, headers=API_HEADERS)

    assert response.status_code == 200
    payload = response.json()
    assert payload["device"] == "cpu"
    mock_bge_model.model.to.assert_called_with("cpu")
|
||||
|
||||
def test_embed_cpu_when_gpu_query_failed(mock_bge_model):
    """/embed falls back to CPU when the VRAM query reports failure."""
    failed_query = MagicMock(
        total_mb=16384.0, used_mb=16384.0, available_mb=0.0, query_success=False
    )
    with patch("app.get_vram_headroom", return_value=failed_query):
        response = client.post("/embed", json={"text": "hello test"}, headers=API_HEADERS)

    assert response.status_code == 200
    payload = response.json()
    assert payload["device"] == "cpu"
    mock_bge_model.model.to.assert_called_with("cpu")
|
||||
|
||||
def test_embed_timeout_returns_504(mock_bge_model):
    """/embed answers 504 when encoding takes longer than RETRIEVAL_TIMEOUT_SECONDS."""
    vram_mock = MagicMock(total_mb=16384.0, used_mb=2000.0, available_mb=14384.0, query_success=True)

    # Simulate an encode call that outlasts the timeout.
    def slow_encode(*args, **kwargs):
        import time
        time.sleep(0.5)
        return {"dense_vecs": [[0.1]], "lexical_weights": [{"1": 0.1}]}

    mock_bge_model.encode.side_effect = slow_encode

    # The endpoint reads the timeout from the environment on every request,
    # so a short per-test override keeps this test fast instead of sleeping
    # for seconds against the module-wide 2.0 s setting.
    with patch("app.get_vram_headroom", return_value=vram_mock), \
         patch.dict(os.environ, {"RETRIEVAL_TIMEOUT_SECONDS": "0.1"}):
        response = client.post("/embed", json={"text": "hello test"}, headers=API_HEADERS)

    assert response.status_code == 504
|
||||
|
||||
def test_rerank_gpu_when_headroom_sufficient(mock_reranker):
    """/rerank runs on CUDA when headroom is above the threshold and CUDA is available."""
    ample_headroom = MagicMock(
        total_mb=16384.0, used_mb=2000.0, available_mb=14384.0, query_success=True
    )
    request_body = {"query": "test query", "chunks": ["chunk1"]}
    with patch("app.get_vram_headroom", return_value=ample_headroom):
        with patch("torch.cuda.is_available", return_value=True):
            response = client.post("/rerank", json=request_body, headers=API_HEADERS)

    assert response.status_code == 200
    payload = response.json()
    assert payload["device"] == "cuda"
    mock_reranker.model.to.assert_called_with("cuda")
|
||||
|
||||
def test_rerank_cpu_when_headroom_insufficient(mock_reranker):
    """/rerank falls back to CPU when available VRAM is below the 3000 MB threshold."""
    tight_headroom = MagicMock(
        total_mb=16384.0, used_mb=14000.0, available_mb=2384.0, query_success=True
    )
    request_body = {"query": "test query", "chunks": ["chunk1"]}
    with patch("app.get_vram_headroom", return_value=tight_headroom):
        response = client.post("/rerank", json=request_body, headers=API_HEADERS)

    assert response.status_code == 200
    payload = response.json()
    assert payload["device"] == "cpu"
    mock_reranker.model.to.assert_called_with("cpu")
|
||||
Reference in New Issue
Block a user