"""FastAPI wrapper exposing LLMLingua-2 prompt compression over HTTP."""

import json
import os
import re

from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.responses import HTMLResponse
from llmlingua import PromptCompressor
from pydantic import BaseModel

# Force CPU inference: hide CUDA devices and let PyTorch fall back from MPS.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

# Model selection: LLMLINGUA_MODEL overrides the default LLMLingua-2 checkpoint;
# LLMLINGUA_API_KEY, when set, turns on API-key checks for /compress.
FALLBACK_MODEL = "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank"
MODEL_NAME = os.environ.get("LLMLINGUA_MODEL", FALLBACK_MODEL)
API_KEY = os.environ.get("LLMLINGUA_API_KEY")

# Service metadata surfaced by the root endpoint.
SERVICE_NAME = os.environ.get("SERVICE_NAME", "llmlingua-gpts")
SERVICE_OWNER = os.environ.get("SERVICE_OWNER", "clancylin")
PRIVACY_EFFECTIVE = os.environ.get("PRIVACY_EFFECTIVE", "2025-11-02")

app = FastAPI(title="LLMLingua Wrapper", version="1.0.0")


def _build_compressor(model_name: str) -> PromptCompressor:
    """Instantiate a CPU-bound LLMLingua-2 compressor for the given model."""
    return PromptCompressor(
        model_name=model_name,
        use_llmlingua2=True,
        device_map="cpu",
        model_config={"low_cpu_mem_usage": True},
    )


# Load the requested model at startup; if that fails (bad name, download error),
# fall back to the known-good default checkpoint.
_loaded_model = MODEL_NAME
try:
    compressor = _build_compressor(MODEL_NAME)
except Exception:
    _loaded_model = FALLBACK_MODEL
    compressor = _build_compressor(FALLBACK_MODEL)


class CompressOut(BaseModel):
    """Response schema for /compress; token counts are best-effort and optional."""

    compressed_text: str
    origin_tokens: int | None = None
    compressed_tokens: int | None = None
    ratio: str | None = None
    rate_used: float | None = None
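
# Illustrative response shape (numbers are made up, not from a real run):
#   {"compressed_text": "...", "origin_tokens": 1480, "compressed_tokens": 740,
#    "ratio": "0.50x", "rate_used": 0.5}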


def verify(x_api_key: str | None = None):
    """Reject the request when an API key is configured and the header mismatches."""
    if API_KEY and x_api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")


def _coerce_numbers(d: dict) -> dict:
    """Best-effort coercion of string values (from forms or JSON) into native types."""
    if "rate" in d and isinstance(d["rate"], str):
        try:
            d["rate"] = float(d["rate"])
        except ValueError:
            pass
    if "target_tokens" in d and isinstance(d["target_tokens"], str):
        try:
            d["target_tokens"] = int(d["target_tokens"])
        except ValueError:
            pass
    if "force_tokens" in d and isinstance(d["force_tokens"], str):
        try:
            d["force_tokens"] = json.loads(d["force_tokens"])
        except json.JSONDecodeError:
            pass
    return d
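
# For example (illustrative), form values arrive as strings and are coerced:
#   _coerce_numbers({"rate": "0.5", "target_tokens": "300", "force_tokens": '["\\n"]'})
#   -> {"rate": 0.5, "target_tokens": 300, "force_tokens": ["\n"]}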


async def _read_any_body(request: Request) -> dict:
    """Parse the request body as JSON, plain text, or form data into a dict."""
    ct = (request.headers.get("content-type") or "").lower()
    if "application/json" in ct:
        raw = await request.body()
        if not raw:
            return {}
        try:
            parsed = json.loads(raw)
        except Exception:
            # Malformed JSON: fall back to treating the raw body as plain text.
            return {"text": raw.decode("utf-8", "ignore")}
        if not isinstance(parsed, dict):
            # Non-object JSON (e.g. a bare string or array): treat it as the text payload.
            return {"text": raw.decode("utf-8", "ignore")}
        return _coerce_numbers(parsed)
    if "text/plain" in ct or not ct:
        return {"text": (await request.body()).decode("utf-8", "ignore")}
    if "multipart/form-data" in ct or "application/x-www-form-urlencoded" in ct:
        form = await request.form()
        data = {k: v for k, v in form.items()}
        return _coerce_numbers(data)
    # Unknown content type: treat the body as plain text.
    return {"text": (await request.body()).decode("utf-8", "ignore")}


def _auto_rate(text: str, target_tokens: int | None) -> float:
    """Pick a compression rate from input length when the caller gives none.

    An explicit target_tokens wins and is clamped to [0.1, 0.95]; code-heavy
    inputs are compressed more gently than prose.
    """
    n = len(compressor.tokenizer.tokenize(text))
    has_code = ("```" in text) or (re.search(r"[{}\[\]]", text) is not None)
    if target_tokens:
        return float(min(0.95, max(0.1, target_tokens / max(1, n))))
    if has_code:
        return 0.7 if n >= 1200 else 0.6
    if n >= 2000:
        return 0.4
    if n >= 1200:
        return 0.5
    return 0.6
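
# Worked examples of the heuristic (token counts are illustrative):
#   2,500-token prose prompt, no target        -> rate 0.4
#   1,500-token prompt containing a ``` block  -> rate 0.7
#   1,000-token prompt with target_tokens=300  -> rate 0.3 (300/1000, within clamp)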


@app.get("/")
def root():
    return {
        "status": "ok",
        "service": SERVICE_NAME,
        "owner": SERVICE_OWNER,
        "requested_model": MODEL_NAME,
        "loaded_model": _loaded_model,
        "endpoints": ["/compress", "/healthz", "/privacy"],
    }


@app.get("/healthz")
def healthz():
    return {"ok": True}


@app.post("/compress", response_model=CompressOut)
async def compress(request: Request, x_api_key: str | None = Header(default=None)):
    verify(x_api_key)

    body = await _read_any_body(request)
    text = body.get("text")
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=422, detail="`text` is required (JSON, text/plain, or form-data).")

    rate = body.get("rate")
    target_tokens = body.get("target_tokens")
    force_tokens = body.get("force_tokens")

    # Forward optional LLMLingua parameters only when they are well-typed.
    kw = {}
    if isinstance(target_tokens, int) and target_tokens > 0:
        kw["target_token"] = target_tokens
    if isinstance(force_tokens, list):
        kw["force_tokens"] = force_tokens

    # Use the caller's rate if given; otherwise derive one from the input.
    if rate is not None:
        try:
            rate_used = float(rate)
        except (TypeError, ValueError):
            raise HTTPException(status_code=422, detail="`rate` must be a number.")
    else:
        rate_used = _auto_rate(text, target_tokens if isinstance(target_tokens, int) else None)

    out = compressor.compress_prompt(text, rate=rate_used, **kw)
    raw = out.get("compressed_prompt", "") or out.get("compressed_text", "")

    # Re-tokenize and detokenize to normalize spacing, then drop spaces before punctuation.
    try:
        toks = compressor.tokenizer.tokenize(raw)
        comp_text = compressor.tokenizer.convert_tokens_to_string(toks)
        comp_text = re.sub(r"[^\S\r\n]+([,.;:!?])", r"\1", comp_text).strip()
    except Exception:
        comp_text = re.sub(r"[^\S\r\n]+", " ", raw or "").strip()

    # Token accounting; `ratio` is the fraction of original tokens retained.
    origin_tokens = len(compressor.tokenizer.tokenize(text)) or 1
    compressed_tokens = len(compressor.tokenizer.tokenize(comp_text))
    ratio = f"{compressed_tokens / origin_tokens:.2f}x"

    return {
        "compressed_text": comp_text,
        "origin_tokens": origin_tokens,
        "compressed_tokens": compressed_tokens,
        "ratio": ratio,
        "rate_used": rate_used,
    }
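
# Example call (illustrative; adjust host, port, and key to your deployment):
#   curl -X POST http://localhost:7860/compress \
#     -H "Content-Type: application/json" \
#     -H "x-api-key: $LLMLINGUA_API_KEY" \
#     -d '{"text": "...long prompt...", "rate": 0.5, "force_tokens": ["\n"]}'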


@app.get("/privacy", response_class=HTMLResponse)
def privacy():
    return """
<!doctype html>
<html><head><meta charset="utf-8"><title>Privacy Policy</title></head>
<body>
<h1>Privacy Policy</h1>
<p>This service compresses text using LLMLingua models hosted on Hugging Face Spaces.</p>
<h2>What we process</h2>
<ul>
<li>Text you send in the <code>text</code> field and optional parameters (<code>rate</code>, etc.).</li>
<li>We do not use cookies or track users on this endpoint.</li>
</ul>
<h2>How we use data</h2>
<ul>
<li>Inputs are used solely to compute the compressed result and return it to the caller.</li>
<li>Application logs may include timestamps, path (<code>/compress</code>), and error traces for reliability.</li>
</ul>
<h2>Retention</h2>
<ul>
<li>We do not persist request bodies after processing. Platform-level logs may be retained by Hugging Face per their policies.</li>
</ul>
<h2>Third parties</h2>
<ul>
<li>Hosted on Hugging Face Spaces; their policies apply to infrastructure-level logging and telemetry.</li>
</ul>
<h2>Security</h2>
<ul>
<li>HTTPS enforced by the hosting platform.</li>
</ul>
<h2>Contact</h2>
<ul>
<li>For inquiries, please use the <strong>Discussions</strong> tab on the project’s
<a href="https://huggingface.co/spaces/ClancyLin/llmlingua">Hugging Face Space</a>.</li>
</ul>
</body></html>
"""


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)