import os
import re
import json
from fastapi import FastAPI, Header, HTTPException, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from llmlingua import PromptCompressor
# ---- Force CPU (avoid CUDA on CPU-only hosts)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "")
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
# ---- Config via env (tweak without code changes)
FALLBACK_MODEL = "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank"
MODEL_NAME = os.environ.get("LLMLINGUA_MODEL", FALLBACK_MODEL)
API_KEY = os.environ.get("LLMLINGUA_API_KEY") # optional
# For /privacy (edit by env if you like)
SERVICE_NAME = os.environ.get("SERVICE_NAME", "llmlingua-gpts")
SERVICE_OWNER = os.environ.get("SERVICE_OWNER", "clancylin")
PRIVACY_EFFECTIVE = os.environ.get("PRIVACY_EFFECTIVE", "2025-11-02")
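# Illustrative only: on a Hugging Face Space these are typically set as Space
# variables/secrets; locally they can be exported before launch, e.g.
#   LLMLINGUA_MODEL=microsoft/llmlingua-2-xlm-roberta-large-meetingbank
#   LLMLINGUA_API_KEY=some-secret-value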
app = FastAPI(title="LLMLingua Wrapper", version="1.0.0")
def _build_compressor(model_name: str) -> PromptCompressor:
    return PromptCompressor(
        model_name=model_name,
        use_llmlingua2=True,
        device_map="cpu",
        model_config={"low_cpu_mem_usage": True},
    )
# Try desired model; fall back to a public multilingual one if it fails
_loaded_model = MODEL_NAME
try:
    compressor = _build_compressor(MODEL_NAME)
except Exception:
    _loaded_model = FALLBACK_MODEL
    compressor = _build_compressor(FALLBACK_MODEL)
# ---- Schemas
class CompressOut(BaseModel):
    compressed_text: str
    origin_tokens: int | None = None
    compressed_tokens: int | None = None
    ratio: str | None = None
    rate_used: float | None = None
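# For illustration only (values are hypothetical), a serialized response might look like:
#   {"compressed_text": "...", "origin_tokens": 1034, "compressed_tokens": 402,
#    "ratio": "0.39x", "rate_used": 0.4}
# Note that "ratio" is compressed/origin tokens, as formatted in /compress below.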
# ---- Optional API key check
def verify(x_api_key: str | None = None):
    if API_KEY and x_api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")
# ---- Body parsing: accept JSON / text / form-data / urlencoded
def _coerce_numbers(d: dict) -> dict:
    if "rate" in d and isinstance(d["rate"], str):
        try:
            d["rate"] = float(d["rate"])
        except ValueError:
            pass
    if "target_tokens" in d and isinstance(d["target_tokens"], str):
        try:
            d["target_tokens"] = int(d["target_tokens"])
        except ValueError:
            pass
    if "force_tokens" in d and isinstance(d["force_tokens"], str):
        try:
            d["force_tokens"] = json.loads(d["force_tokens"])
        except ValueError:
            pass
    return d
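# Example of the coercion above (illustrative, not executed):
#   _coerce_numbers({"rate": "0.4", "target_tokens": "300", "force_tokens": '["def", "return"]'})
#   -> {"rate": 0.4, "target_tokens": 300, "force_tokens": ["def", "return"]}
# Values that fail to parse are left as-is rather than rejected.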
async def _read_any_body(request: Request) -> dict:
    ct = (request.headers.get("content-type") or "").lower()
    if "application/json" in ct:
        raw = await request.body()
        if not raw:
            return {}
        try:
            return _coerce_numbers(json.loads(raw))
        except Exception:
            # Fallback: treat whole body as plain text
            return {"text": raw.decode("utf-8", "ignore")}
    if "text/plain" in ct or not ct:
        return {"text": (await request.body()).decode("utf-8", "ignore")}
    if "multipart/form-data" in ct or "application/x-www-form-urlencoded" in ct:
        form = await request.form()
        data = {k: v for k, v in form.items()}
        return _coerce_numbers(data)
    return {"text": (await request.body()).decode("utf-8", "ignore")}
# ---- Heuristic rate when not provided
def _auto_rate(text: str, target_tokens: int | None) -> float:
    n = len(compressor.tokenizer.tokenize(text))
    has_code = ("```" in text) or (re.search(r"[{}\[\]]", text) is not None)
    if target_tokens:
        return float(min(0.95, max(0.1, target_tokens / max(1, n))))
    if has_code:
        return 0.7 if n >= 1200 else 0.6
    if n >= 2000:
        return 0.4
    if n >= 1200:
        return 0.5
    return 0.6
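# Worked examples of the heuristic above (token counts are hypothetical):
#   1,000-token prose with target_tokens=300 -> 300/1000 = 0.3 (clamped to [0.1, 0.95])
#   2,500-token prose, no target             -> 0.4
#   1,500-token text containing ``` or {}    -> 0.7 (code-like text keeps more tokens)
#   800-token prose                          -> 0.6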
# ---- Routes
@app.get("/")
def root():
    return {
        "status": "ok",
        "service": SERVICE_NAME,
        "owner": SERVICE_OWNER,
        "requested_model": MODEL_NAME,
        "loaded_model": _loaded_model,
        "endpoints": ["/compress", "/healthz", "/privacy"],
    }
@app.get("/healthz")
def healthz():
    return {"ok": True}
@app.post("/compress", response_model=CompressOut)
async def compress(request: Request, x_api_key: str | None = Header(default=None)):
    verify(x_api_key)  # remove if you don't use API keys
    body = await _read_any_body(request)
    text = body.get("text")
    if not isinstance(text, str) or not text.strip():
        raise HTTPException(status_code=422, detail="`text` is required (JSON, text/plain, or form-data).")
    rate = body.get("rate", None)
    target_tokens = body.get("target_tokens", None)
    force_tokens = body.get("force_tokens", None)
    kw = {}
    if isinstance(target_tokens, int) and target_tokens > 0:
        kw["target_token"] = target_tokens
    if isinstance(force_tokens, list):
        kw["force_tokens"] = force_tokens
    rate_used = float(rate) if rate is not None else _auto_rate(text, target_tokens if isinstance(target_tokens, int) else None)
    out = compressor.compress_prompt(text, rate=rate_used, **kw)
    raw = out.get("compressed_prompt", "") or out.get("compressed_text", "")
    # Detokenize to avoid "char + space" artifacts
    try:
        toks = compressor.tokenizer.tokenize(raw)
        comp_text = compressor.tokenizer.convert_tokens_to_string(toks)
        comp_text = re.sub(r"[^\S\r\n]+([,.;:!?])", r"\1", comp_text).strip()
    except Exception:
        comp_text = re.sub(r"[^\S\r\n]+", " ", raw or "").strip()
    origin_tokens = len(compressor.tokenizer.tokenize(text)) or 1
    compressed_tokens = len(compressor.tokenizer.tokenize(comp_text))
    ratio = f"{compressed_tokens / origin_tokens:.2f}x"
    return {
        "compressed_text": comp_text,
        "origin_tokens": origin_tokens,
        "compressed_tokens": compressed_tokens,
        "ratio": ratio,
        "rate_used": rate_used,
    }
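# ---- Example request (illustrative sketch; `requests` is not a dependency of this app)
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/compress",                      # assumes a local run on the default port
#       json={"text": "a very long prompt ...", "rate": 0.5},
#       headers={"x-api-key": "my-secret"},                    # only needed if LLMLINGUA_API_KEY is set
#       timeout=120,
#   )
#   print(resp.json()["compressed_text"])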
@app.get("/privacy", response_class=HTMLResponse)
def privacy():
    return """
<!doctype html>
<html><head><meta charset="utf-8"><title>Privacy Policy</title></head>
<body>
<h1>Privacy Policy</h1>
<p>This service compresses text using LLMLingua models hosted on Hugging Face Spaces.</p>
<h2>What we process</h2>
<ul>
<li>Text you send in the <code>text</code> field and optional parameters (<code>rate</code>, etc.).</li>
<li>We do not use cookies or track users on this endpoint.</li>
</ul>
<h2>How we use data</h2>
<ul>
<li>Inputs are used solely to compute the compressed result and return it to the caller.</li>
<li>Application logs may include timestamps, path (<code>/compress</code>), and error traces for reliability.</li>
</ul>
<h2>Retention</h2>
<ul>
<li>We do not persist request bodies after processing. Platform-level logs may be retained by Hugging Face per their policies.</li>
</ul>
<h2>Third parties</h2>
<ul>
<li>Hosted on Hugging Face Spaces; their policies apply to infrastructure-level logging and telemetry.</li>
</ul>
<h2>Security</h2>
<ul>
<li>HTTPS enforced by the hosting platform.</li>
</ul>
<h2>Contact</h2>
<ul>
<li>For inquiries, please use the <strong>Discussions</strong> tab on the project’s
<a href="https://huggingface.co/spaces/ClancyLin/llmlingua">Hugging Face Space</a>.</li>
</ul>
</body></html>
"""
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "7860")), reload=False)