# File: specs/04-Infrastructure-OPS/04-00-docker-compose/Desk-5439/ocr-sidecar/app.py # OCR HTTP Sidecar API — รับ POST /ocr แล้วคืนข้อความที่สกัดจาก PDF/Image # ตาม ADR-023A (revised 2026-06-11): ใช้ np-dms-ocr (Ollama) แทน Tesseract # Change Log: # - 2026-05-25: Initial FastAPI server สำหรับ Tesseract OCR sidecar # - 2026-05-30: เปลี่ยน lang='en' เป็น lang='ch' (CTJK) เพื่อรองรับภาษาไทย # - 2026-05-30: เปลี่ยนจาก PaddleOCR เป็น Tesseract OCR เพื่อความเข้ากันได้กับ CPU เก่า # - 2026-05-30: เพิ่ม OpenCV preprocessing (threshold, denoise) และ DPI 300 เพื่อเพิ่มความแม่นยำ # - 2026-06-01: เพิ่ม POST /ocr-upload รับ multipart file โดยตรง ไม่ต้องพึ่ง shared volume mount # - 2026-06-01: เปลี่ยน OCR_MODEL default เป็น scb10x/typhoon-ocr1.5-3b # - 2026-06-02: เพิ่มตัวเลือกสลับโมเดลใน process_ocr ตามพารามิเตอร์ engine และตั้ง engineUsed ให้ตรงตามจริง (T015, ADR-033) # - 2026-06-04: ADR-034 — เพิ่ม np-dms-ocr เป็น canonical engine key; default OCR_MODEL เป็น np-dms-ocr:latest; alias โมเดลเก่ายังคงไว้ # - 2026-06-04: ให้ SYSTEM ใน Modelfile ทำงานแทน — ลบ prompt ซ้าซ้อน; sync options ให้ตรงกับ Modelfile (temperature 0.1, top_p 0.1, repeat_penalty 1.1) # - 2026-06-04: รับค่า temperature/top_p/repeat_penalty จาก frontend sandbox ได้ (optional override) # - 2026-06-04: แก้ bug prompt="" ทำให้ Ollama ไม่ generate — เปลี่ยนเป็น minimal trigger prompt # - 2026-06-04: เพิ่ม alias normalization สำหรับ engine name เก่า (typhoon-ocr1.5-3b → np-dms-ocr) # - 2026-06-04: เพิ่ม OCR_DPI=150 (แยกจาก Tesseract DPI=300) — ลด image token count 4x เพื่อเร่ง CPU inference (model >8GB ไม่พอ VRAM) # - 2026-06-04: ส่ง color image (ไม่ผ่าน preprocess_image) ไปยัง np-dms-ocr — vision model ต้องการ color ไม่ใช่ binarized grayscale # - 2026-06-04: เพิ่ม num_gpu:99 ใน Ollama options เพื่อบังคับ GPU layers (แก้ device=CPU ทั้งที่ VRAM พอ) # - 2026-06-02: เพิ่มการตรวจสอบ API Key (X-API-Key Header) สำหรับ endpoints หลัก เพื่อความมั่นคงปลอดภัยตามข้อเสนอแนะ Code Review # - 2026-06-05: เพิ่ม Option 2 (aggressive preprocessing: deskew + Otsu threshold + morphology) และ Option 3 (smart post-processing: regex-based hallucination removal) เพื่อลด Tesseract noise/hallucination (T025) # - 2026-06-06: เปลี่ยน keep_alive จาก 300s เป็น 0 เพื่อ unload model ทันทีหลังเสร็จงาน (แก้ปัญหา VRAM ไม่พอเมื่อ np-dms-ai load พร้อมกัน) # - 2026-06-11: เปลี่ยน process_ocr ให้ใช้ prepare_ocr_messages จาก typhoon_ocr library + inject DMS tags; เปลี่ยน endpoint เป็น /v1/chat/completions # - 2026-06-11: US2 & US3 - เพิ่ม keep_alive parameter และ CPU fallback สำหรับ /embed และ /rerank # - 2026-06-13: ADR-036 — เปลี่ยน canonical engine/model เป็น np-dms-ocr และคง legacy aliases # - 2026-06-17: เปลี่ยนชื่อ environment variable จาก TYPHOON_OCR_MODEL → OCR_MODEL และ TYPHOON_OCR_TIMEOUT → OCR_TIMEOUT เพื่อ consistency กับ ADR-036 # - 2026-06-17: ลบชื่อ Typhoon ออกจากทุกส่วน: process_with_typhoon_ocr → process_ocr, FastAPI title, comments, ตัวแปรต่างๆ # - 2026-06-17: เพิ่ม systemPrompt parameter ใน /ocr-upload, _process_pdf_doc, process_ocr เพื่อรองรับ dynamic OCR system prompt injection (T026-T028) # - 2026-06-18: เพิ่ม MAX_SYSTEM_PROMPT_LENGTH environment variable สำหรับ configurable validation (fix-3) # - 2026-06-20: ADR-040 Phase 1-4 — ลบ default API key, เพิ่ม path whitelist, และ wire adaptive OCR residency # - 2026-06-20: ADR-040 Phase 6 — async I/O refactor: async process_ocr, AsyncClient via lifespan, asyncio.to_thread model loading # - 2026-06-20: ADR-040 Phase 8 — ลบ /normalize endpoint (ไม่มี consumers) และ pythainlp imports import os import logging import re import json import tempfile import fitz # PyMuPDF (ใช้สำหรับ page count + fast-path text extraction) import httpx import asyncio from contextlib import asynccontextmanager from pathlib import Path from typing import Optional from typhoon_ocr import prepare_ocr_messages # External library from SCB10X (PyPI) — provides OCR message preparation for np-dms-ocr from services.vram_monitor import get_vram_headroom from services.residency_policy import calculate_ocr_residency from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Depends, Security, status from fastapi.security.api_key import APIKeyHeader from pydantic import BaseModel from FlagEmbedding import BGEM3FlagModel, FlagReranker logging.basicConfig(level=logging.INFO) logger = logging.getLogger("ocr-sidecar") # Initialize BGE-M3 and Reranker singletons bge_model = None reranker = None # Shared AsyncClient สำหรับ Ollama API (T043: สร้างใน lifespan context manager) ollama_client: httpx.AsyncClient | None = None def _load_bge_models() -> tuple: """โหลด BGE-M3 และ Reranker models บน CPU RAM (T046: เรียกผ่าน asyncio.to_thread)""" logger.info("Loading BGE-M3 and Reranker models on CPU RAM...") try: bge = BGEM3FlagModel('BAAI/bge-m3', use_fp16=False) rerank = FlagReranker('BAAI/bge-reranker-large', use_fp16=False) logger.info("BGE-M3 and Reranker models loaded successfully.") return bge, rerank except Exception as e: logger.error(f"Failed to load BGE models: {e}") return None, None @asynccontextmanager async def lifespan(app_instance: FastAPI): """T043/T045: Lifespan context manager แทน @app.on_event('startup') — จัดการ AsyncClient และ model loading""" global bge_model, reranker, ollama_client # T043: สร้าง shared AsyncClient สำหรับ Ollama API ollama_client = httpx.AsyncClient(timeout=OCR_TIMEOUT) logger.info(f"Shared AsyncClient created (timeout={OCR_TIMEOUT}s)") # T046: โหลด models ผ่าน asyncio.to_thread เพื่อไม่ block startup bge_model, reranker = await asyncio.to_thread(_load_bge_models) yield # Cleanup: ปิด AsyncClient if ollama_client: await ollama_client.aclose() logger.info("Shared AsyncClient closed.") app = FastAPI(title="OCR Sidecar", version="2.0.0", lifespan=lifespan) # กำหนดค่าโทเค็นความปลอดภัยของ Sidecar ตามข้อเสนอแนะในการรักษาความมั่นคงปลอดภัย OCR_SIDECAR_API_KEY = os.getenv("OCR_SIDECAR_API_KEY") if not OCR_SIDECAR_API_KEY: raise RuntimeError("OCR_SIDECAR_API_KEY is required for OCR sidecar startup") # กำหนดค่าความยาวสูงสุดของ systemPrompt (fix-3: configurable validation) MAX_SYSTEM_PROMPT_LENGTH = int(os.getenv("MAX_SYSTEM_PROMPT_LENGTH", "10000")) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) async def get_api_key(api_key: str = Security(api_key_header)): if not api_key: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing API Key in request headers (X-API-Key)") if api_key != OCR_SIDECAR_API_KEY: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key") return api_key # อ่านค่า config จาก environment OCR_CHAR_THRESHOLD = int(os.getenv("OCR_CHAR_THRESHOLD", "100")) MAX_PAGES = int(os.getenv("OCR_MAX_PAGES", "0")) # 0 = ทุกหน้า OLLAMA_API_URL = os.getenv("OLLAMA_API_URL", "http://host.docker.internal:11434") OCR_MODEL = os.getenv("OCR_MODEL", "np-dms-ocr:latest") OCR_TIMEOUT = int(os.getenv("OCR_TIMEOUT", "360")) # รองรับ cold-start ~65s + inference ~30s/page OCR_SIDECAR_UPLOAD_BASE = os.getenv("OCR_SIDECAR_UPLOAD_BASE", "/mnt/uploads") OCR_ACTIVE_PROFILE = os.getenv("OCR_ACTIVE_PROFILE") logger.info(f"OCR Sidecar initialized (model={OCR_MODEL}, ollama={OLLAMA_API_URL})") def filter_ocr_noise(text: str) -> str: """กรองสัญลักษณ์ที่ไม่มีความหมายออกจาก Markdown output""" lines = text.split("\n") filtered = [] for line in lines: line = line.strip() if not line: continue alphanumeric_part = re.sub(r'[^\w\u0E00-\u0E7F]', '', line) if len(alphanumeric_part) < 2: continue filtered.append(line) return "\n".join(filtered) def validate_pdf_path(pdf_path: str) -> Path: """Canonicalize path และยืนยันว่าอยู่ใต้ OCR_SIDECAR_UPLOAD_BASE""" canonical_path = os.path.abspath(os.path.realpath(pdf_path)) canonical_base = os.path.abspath(os.path.realpath(OCR_SIDECAR_UPLOAD_BASE)) try: common_path = os.path.commonpath([canonical_path, canonical_base]) except ValueError: common_path = "" if common_path != canonical_base: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Path outside whitelisted base directory", ) return Path(canonical_path) class OcrRequest(BaseModel): pdfPath: str maxPages: Optional[int] = None engine: Optional[str] = None keep_alive: Optional[int] = None runtime_params: Optional[dict] = None system_prompt: Optional[str] = None dms_tags: Optional[dict] = None class OcrResponse(BaseModel): text: str ocrUsed: bool pageCount: int charCount: int engineUsed: str @app.get("/health") def health(): return { "status": "ok", "engine": "np-dms-ocr", "ocrModel": OCR_MODEL, "ollamaUrl": OLLAMA_API_URL, } async def _process_pdf_doc( doc: fitz.Document, selected_engine: str, max_pages: int, ocr_options: Optional[dict] = None, pdf_path: str | None = None, system_prompt: Optional[str] = None, runtime_params: Optional[dict] = None, dms_tags: Optional[dict] = None, ) -> OcrResponse: """ประมวลผล fitz.Document ด้วย engine ที่เลือก — shared logic สำหรับ /ocr และ /ocr-upload""" ocr_options = ocr_options or {} pages_to_process = list(range(min(len(doc), max_pages) if max_pages > 0 else len(doc))) page_count = len(pages_to_process) fast_text_parts = [] total_chars = 0 if selected_engine == "auto": for i in pages_to_process: page = doc[i] fast_text_parts.append(page.get_text()) fast_text = "\n".join(fast_text_parts).strip() total_chars = len(fast_text) if total_chars > OCR_CHAR_THRESHOLD: logger.info(f"Fast path: {total_chars} chars extracted") return OcrResponse( text=fast_text, ocrUsed=False, pageCount=page_count, charCount=total_chars, engineUsed="fast-path", ) if selected_engine == "np-dms-ocr": # ใช้ prepare_ocr_messages รับ PDF path โดยตรง — ไม่ต้องแปลง PIL Image อีกต่อไป resolved_path = pdf_path or (str(doc.name) if hasattr(doc, 'name') and doc.name else None) if not resolved_path: raise ValueError("ไม่สามารถหา PDF path — ต้องส่ง pdf_path เข้ามาด้วย") ocr_text_parts = [] for i in pages_to_process: ocr_text_parts.append( await process_ocr( resolved_path, page_num=i + 1, options_override=ocr_options, system_prompt=system_prompt, runtime_params=runtime_params, dms_tags=dms_tags, ) ) ocr_text = filter_ocr_noise("\n".join(ocr_text_parts).strip()) return OcrResponse( text=ocr_text, ocrUsed=True, pageCount=page_count, charCount=len(ocr_text), engineUsed=selected_engine, ) # ถ้าไม่ใช่ engine ที่รู้จัก ให้ใช้ np-dms-ocr เป็น fallback logger.warning(f"Unknown engine '{selected_engine}' — fallback to np-dms-ocr") resolved_path = pdf_path or (str(doc.name) if hasattr(doc, 'name') and doc.name else None) if not resolved_path: raise ValueError("ไม่สามารถหา PDF path — ต้องส่ง pdf_path เข้ามาด้วย") fallback_parts = [] for i in pages_to_process: fallback_parts.append( await process_ocr( resolved_path, page_num=i + 1, options_override=ocr_options, system_prompt=system_prompt, runtime_params=runtime_params, dms_tags=dms_tags, ) ) fallback_text = filter_ocr_noise("\n".join(fallback_parts).strip()) return OcrResponse( text=fallback_text, ocrUsed=True, pageCount=page_count, charCount=len(fallback_text), engineUsed="np-dms-ocr", ) async def process_ocr( pdf_path: str, page_num: int = 1, options_override: Optional[dict] = None, system_prompt: Optional[str] = None, runtime_params: Optional[dict] = None, dms_tags: Optional[dict] = None, ) -> str: """เรียก np-dms-ocr ผ่าน Ollama /v1/chat/completions — รับ PDF path โดยตรง ไม่ต้องแปลง PIL Image""" options_override = options_override or {} if "keep_alive" in options_override: raise ValueError("keep_alive must be calculated by OCR residency policy") residency = await asyncio.to_thread(calculate_ocr_residency, OCR_ACTIVE_PROFILE) model_name = OCR_MODEL # prepare_ocr_messages จัดการ PDF → image ผ่าน poppler/pdftoppm ภายใน messages = prepare_ocr_messages(pdf_path, task_type="structure", page_num=page_num) # inject system prompt ถ้ามี (ก่อน DMS tags) if system_prompt: messages[0]["content"].append({"type": "text", "text": system_prompt}) # Dynamic dms_tags mapping to prompts if dms_tags: dms_text = "Additionally:\n" for key in dms_tags.keys(): readable_name = re.sub(r'(?...\n" dms_text += "If a field is not found, omit the tag." else: # Fallback to default DMS extraction tags dms_text = ( "Additionally:\n" "- Wrap document number with ...\n" "- Wrap document date with ...\n" "- Wrap received date with ...\n" "If a field is not found, omit the tag." ) # inject DMS-specific extraction tags ต่อท้าย content messages[0]["content"].append({ "type": "text", "text": dms_text, }) # Resolve runtime parameters: remove hardcoded fallback values from sidecar # Use empty dict if runtime_params not provided to allow Ollama Modelfile default params = {} if runtime_params: if hasattr(runtime_params, "dict"): params = runtime_params.dict() elif isinstance(runtime_params, dict): params = runtime_params # Options override (e.g., from Sandbox form parameter overrides) takes precedence merged_params = {} if params: merged_params.update(params) if options_override: merged_params.update(options_override) # ค่า default ตาม official; options_override ยัง override ได้บางส่วน logger.info( f"OCR residency decision: keep_alive={residency.keep_alive_seconds}s " f"reason={residency.reason} headroom={residency.vram_headroom_mb}MB" ) payload = { "model": model_name, "messages": messages, "stream": False, "keep_alive": residency.keep_alive_seconds, } # Only send keys to Ollama if they are defined in merged_params (to support Modelfile fallback) if "temperature" in merged_params and merged_params["temperature"] is not None: payload["temperature"] = float(merged_params["temperature"]) if "top_p" in merged_params and merged_params["top_p"] is not None: payload["top_p"] = float(merged_params["top_p"]) if "repeat_penalty" in merged_params and merged_params["repeat_penalty"] is not None: payload["repetition_penalty"] = float(merged_params["repeat_penalty"]) elif "repetition_penalty" in merged_params and merged_params["repetition_penalty"] is not None: payload["repetition_penalty"] = float(merged_params["repetition_penalty"]) if "max_tokens" in merged_params and merged_params["max_tokens"] is not None: payload["max_tokens"] = int(merged_params["max_tokens"]) # T044: ใช้ shared AsyncClient (ollama_client) แทน httpx.Client แบบ sync # ถ้า ollama_client ยังไม่ถูกสร้าง (เช่น unit test ที่เรียกตรง) ให้สร้างชั่วคราว client = ollama_client if client is None: client = httpx.AsyncClient(timeout=OCR_TIMEOUT) response = await client.post( f"{OLLAMA_API_URL}/v1/chat/completions", json=payload, headers={"Authorization": "Bearer ollama"}, ) response.raise_for_status() data = response.json() raw_text = str(data.get("choices", [{}])[0].get("message", {}).get("content", "")).strip() # parse JSON output จาก model (format: {"natural_text": "..."}) try: result_text = json.loads(raw_text).get("natural_text", raw_text) except (json.JSONDecodeError, AttributeError): result_text = raw_text logger.info( f"[DIAG] Ollama response — model={model_name} " f"textLen={len(result_text)} " f"done={data.get('done')} " f"done_reason={data.get('done_reason')} " f"eval_count={data.get('eval_count', 0)}" ) if not result_text: logger.warning( f"[DIAG] Ollama returned empty response — full response keys: {list(data.keys())}" ) # ปิด temporary client ถ้าสร้างชั่วคราว if ollama_client is None: await client.aclose() return result_text @app.post("/ocr", response_model=OcrResponse, dependencies=[Depends(get_api_key)]) async def ocr_extract(req: OcrRequest): """OCR จาก path (legacy — ใช้เมื่อ sidecar และ backend เข้าถึง storage เดียวกัน)""" if req.keep_alive is not None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="keep_alive is managed by OCR residency policy") pdf_path = validate_pdf_path(req.pdfPath) if not pdf_path.exists(): raise HTTPException(status_code=404, detail=f"ไม่พบไฟล์: {req.pdfPath}") selected_engine = (req.engine or "auto").strip().lower() max_pages = req.maxPages or MAX_PAGES ocr_options = {} try: doc = fitz.open(str(pdf_path)) except Exception as e: raise HTTPException(status_code=422, detail=f"เปิดไฟล์ PDF ล้มเหลว: {e}") return await _process_pdf_doc( doc, selected_engine, max_pages, ocr_options, pdf_path=str(pdf_path), system_prompt=req.system_prompt, runtime_params=req.runtime_params, dms_tags=req.dms_tags, ) @app.post("/ocr-upload", response_model=OcrResponse, dependencies=[Depends(get_api_key)]) async def ocr_upload( file: UploadFile = File(...), engine: str = Form(default="auto"), maxPages: int = Form(default=0), temperature: Optional[float] = Form(default=None), topP: Optional[float] = Form(default=None), repeatPenalty: Optional[float] = Form(default=None), maxTokens: Optional[int] = Form(default=None), keep_alive: Optional[int] = Form(default=None), systemPrompt: Optional[str] = Form(default=None), dmsTags: Optional[str] = Form(default=None), runtimeParams: Optional[str] = Form(default=None), ): """OCR จาก multipart file upload — ไม่ต้องการ shared volume mount""" # Validate systemPrompt ถ้ามีส่งมา (gap-1: sidecar validation) if systemPrompt is not None: systemPrompt = systemPrompt.strip() if not systemPrompt: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="systemPrompt cannot be empty if provided" ) if len(systemPrompt) > MAX_SYSTEM_PROMPT_LENGTH: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"systemPrompt exceeds maximum length of {MAX_SYSTEM_PROMPT_LENGTH} characters" ) selected_engine = engine.strip().lower() max_pages = maxPages or MAX_PAGES # Parse runtimeParams and dmsTags from form-data JSON strings if provided runtime_params_dict = {} if runtimeParams: try: runtime_params_dict = json.loads(runtimeParams) except Exception as e: logger.warning(f"Failed to parse runtimeParams JSON: {e}") dms_tags_dict = None if dmsTags: try: dms_tags_dict = json.loads(dmsTags) except Exception as e: logger.warning(f"Failed to parse dmsTags JSON: {e}") # รวม options override สำหรับ np-dms-ocr (ถ้า frontend ส่งมา) ocr_options: dict = {} if temperature is not None: ocr_options["temperature"] = temperature if topP is not None: ocr_options["top_p"] = topP if repeatPenalty is not None: ocr_options["repeat_penalty"] = repeatPenalty if maxTokens is not None: ocr_options["max_tokens"] = maxTokens if keep_alive is not None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="keep_alive is managed by OCR residency policy") pdf_bytes = file.file.read() tmp_pdf_path: str | None = None try: # บันทึก PDF เป็น temp file เพื่อให้ prepare_ocr_messages อ่านได้ผ่าน path with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp: tmp.write(pdf_bytes) tmp_pdf_path = tmp.name try: doc = fitz.open(stream=pdf_bytes, filetype="pdf") except Exception as e: raise HTTPException(status_code=422, detail=f"เปิดไฟล์ PDF ล้มเหลว: {e}") logger.info(f"OCR upload: {file.filename} engine={selected_engine} options={ocr_options or 'modelfile-defaults'}") return await _process_pdf_doc( doc, selected_engine, max_pages, ocr_options, pdf_path=tmp_pdf_path, system_prompt=systemPrompt, runtime_params=runtime_params_dict, dms_tags=dms_tags_dict, ) finally: if tmp_pdf_path: Path(tmp_pdf_path).unlink(missing_ok=True) class EmbedRequest(BaseModel): text: str class EmbedResponse(BaseModel): dense: list[float] sparse: dict device: Optional[str] = None class RerankRequest(BaseModel): query: str chunks: list[str] class RerankResponse(BaseModel): scores: list[float] ranked_indices: list[int] device: Optional[str] = None @app.post("/embed", response_model=EmbedResponse, dependencies=[Depends(get_api_key)]) async def embed_text(req: EmbedRequest): """BGE-M3 embedding generator (Dense + Sparse) พร้อม CPU fallback และ timeout guard""" if bge_model is None: raise HTTPException(status_code=503, detail="BGE-M3 model not loaded") threshold_mb = float(os.getenv("VRAM_HEADROOM_THRESHOLD_MB", "3000.0")) timeout_sec = float(os.getenv("RETRIEVAL_TIMEOUT_SECONDS", "30.0")) headroom = await asyncio.to_thread(get_vram_headroom) device = "cuda" reason = "headroom-sufficient" if not headroom.query_success: device = "cpu" reason = "gpu-query-failed" elif headroom.available_mb < threshold_mb: device = "cpu" reason = "gpu-headroom-below-threshold" try: if device == "cuda": import torch if torch.cuda.is_available(): bge_model.model.to("cuda") else: device = "cpu" reason = "cuda-not-available" bge_model.model.to("cpu") else: bge_model.model.to("cpu") except Exception as e: logger.warning(f"Failed to move BGE-M3 model to {device}: {e}") device = "cpu" reason = f"device-move-failed: {str(e)}" try: bge_model.model.to("cpu") except Exception: pass logger.info(f"Embedding on device: {device} (reason: {reason})") def run_inference(): output = bge_model.encode([req.text], return_dense=True, return_sparse=True) dense_vector = [float(x) for x in output['dense_vecs'][0]] lexical_dict = output['lexical_weights'][0] indices = [] values = [] for token_id, weight in lexical_dict.items(): indices.append(int(token_id)) values.append(float(weight)) return dense_vector, indices, values try: dense_vector, indices, values = await asyncio.wait_for( asyncio.to_thread(run_inference), timeout=timeout_sec ) return EmbedResponse( dense=dense_vector, sparse={"indices": indices, "values": values}, device=device ) except asyncio.TimeoutError: logger.error(f"Embedding generation timed out after {timeout_sec}s on device {device}") raise HTTPException(status_code=504, detail="Embedding generation timed out") except Exception as e: logger.error(f"Embedding generation failed: {e}") raise HTTPException(status_code=500, detail=f"Embedding generation failed: {str(e)}") @app.post("/rerank", response_model=RerankResponse, dependencies=[Depends(get_api_key)]) async def rerank_chunks(req: RerankRequest): """BGE-Reranker-Large chunk re-ranker พร้อม CPU fallback และ timeout guard""" if reranker is None: raise HTTPException(status_code=503, detail="Reranker model not loaded") if not req.chunks: return RerankResponse(scores=[], ranked_indices=[], device="cpu") threshold_mb = float(os.getenv("VRAM_HEADROOM_THRESHOLD_MB", "3000.0")) timeout_sec = float(os.getenv("RETRIEVAL_TIMEOUT_SECONDS", "30.0")) headroom = await asyncio.to_thread(get_vram_headroom) device = "cuda" reason = "headroom-sufficient" if not headroom.query_success: device = "cpu" reason = "gpu-query-failed" elif headroom.available_mb < threshold_mb: device = "cpu" reason = "gpu-headroom-below-threshold" try: if device == "cuda": import torch if torch.cuda.is_available(): reranker.model.to("cuda") else: device = "cpu" reason = "cuda-not-available" reranker.model.to("cpu") else: reranker.model.to("cpu") except Exception as e: logger.warning(f"Failed to move Reranker model to {device}: {e}") device = "cpu" reason = f"device-move-failed: {str(e)}" try: reranker.model.to("cpu") except Exception: pass logger.info(f"Reranking on device: {device} (reason: {reason})") def run_rerank(): pairs = [[req.query, chunk] for chunk in req.chunks] scores = reranker.compute_score(pairs) if isinstance(scores, float): scores = [scores] else: scores = [float(s) for s in scores] indexed_scores = list(enumerate(scores)) indexed_scores.sort(key=lambda x: x[1], reverse=True) ranked_indices = [idx for idx, _ in indexed_scores] return scores, ranked_indices try: scores, ranked_indices = await asyncio.wait_for( asyncio.to_thread(run_rerank), timeout=timeout_sec ) return RerankResponse( scores=scores, ranked_indices=ranked_indices, device=device ) except asyncio.TimeoutError: logger.error(f"Reranking timed out after {timeout_sec}s on device {device}") raise HTTPException(status_code=504, detail="Reranking timed out") except Exception as e: logger.error(f"Reranking failed: {e}") raise HTTPException(status_code=500, detail=f"Reranking failed: {str(e)}") if __name__ == "__main__": import uvicorn port = int(os.getenv("OCR_PORT", "8765")) uvicorn.run(app, host="0.0.0.0", port=port)