---
license: mit
language:
- en
library_name: transformers
tags:
- lora
- peft
- reinforcement-learning
- domain-adaptation
- sentence-embeddings
- curriculum-learning
- multi-task-learning
- rag
- information-retrieval
- cross-domain
- sentence-transformers
base_model:
- sentence-transformers/all-MiniLM-L6-v2
- EphAsad/FireDevourerEmbedder-RL-v3.6
pipeline_tag: sentence-similarity
datasets:
- sentence-transformers/stsb
- nyu-mll/multi_nli
- quora
- google-research-datasets/paws
- nyu-mll/glue
- GBaker/MedQA-USMLE-4-options-hf
- lex_glue
- gbharti/finance-alpaca
- scientific_papers
model-index:
- name: DomainEmbedder-v2.6
  results:
  - task:
      type: domain-classification
      name: Domain Classification
    metrics:
    - type: accuracy
      value: 0.925
      name: Training Accuracy
    - type: accuracy
      value: 0.56
      name: Stress-Test Accuracy
---

# DomainEmbedder-v2.6

> **High-Information-Density Embeddings for Cross-Domain RAG and Retrieval**

DomainEmbedder-v2.6 produces **information-dense embeddings** optimized for retrieval-augmented generation (RAG) and cross-domain similarity matching. It combines a multi-task base embedder with domain-adaptive LoRA routing.

## What This Model Does

| Component | Description |
|-----------|-------------|
| **Base Embedder** | FireDevourerEmbedder-RL-v3.6, trained on 5 NLP tasks with RL-based task weighting |
| **Domain LoRAs** | 5 specialized adapters (Medical, Legal, Code, Finance, Scientific) |
| **RL Policy** | Automatically selects the optimal domain adapter for any input |

**Why this matters for RAG/retrieval:**

- Embeddings encode multiple facets of meaning (similarity, entailment, paraphrase, questions)
- Domain routing provides context-appropriate representations
- Retrieval becomes more precise across diverse content types
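
A minimal sketch of how such embeddings are typically consumed in a retrieval step (plain NumPy, with random vectors standing in for model output; `cosine_top_k` is an illustrative helper, not part of this package):

```python
import numpy as np

def cosine_top_k(query_emb, doc_embs, k=3):
    """Rank documents by cosine similarity to a query embedding."""
    q = query_emb / np.linalg.norm(query_emb)
    d = doc_embs / np.linalg.norm(doc_embs, axis=1, keepdims=True)
    scores = d @ q                 # cosine similarity per document
    top = np.argsort(-scores)[:k]  # indices of the k best matches
    return top, scores[top]

# Toy 384-dim vectors standing in for embeddings
rng = np.random.default_rng(0)
docs = rng.normal(size=(10, 384))
query = docs[4] + 0.01 * rng.normal(size=384)  # near-duplicate of doc 4
idx, scores = cosine_top_k(query, docs, k=3)
print(idx[0])  # doc 4 ranks first
```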

## Key Innovation: Dual RL Architecture

| Stage | RL Application | Purpose |
|-------|---------------|---------|
| Base Model Training | Task Weight Policy | Dynamically balance 5 NLP objectives during training |
| Domain Extension | Adapter Selection Policy | Route to the appropriate domain LoRA at inference |

The approach applies RL twice: **at training time (task weighting) and at inference time (adapter routing)**.

## Quick Start

### Installation

```bash
pip install torch transformers peft
```

### Loading the Model
```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from peft import PeftModel

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Define the base embedder architecture
class FireDevourerEmbedder(nn.Module):
    def __init__(self, base_model_name='sentence-transformers/all-MiniLM-L6-v2'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base_model_name)
        self.hidden_size = 384

        # Task heads
        self.sts_head = nn.Sequential(nn.Linear(384, 1), nn.Sigmoid())
        self.nli_head = nn.Linear(384, 3)
        self.qqp_head = nn.Linear(384, 2)
        self.paws_head = nn.Linear(384, 2)
        self.domain_head = nn.Linear(384, 5)

    def mean_pool(self, token_embeddings, attention_mask):
        mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * mask, 1) / torch.clamp(mask.sum(1), min=1e-9)

    def forward(self, input_ids, attention_mask, task='encode'):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        embedding = self.mean_pool(outputs.last_hidden_state, attention_mask)

        if task == 'encode':
            return embedding
        elif task == 'domain':
            return self.domain_head(embedding)
        raise ValueError(f"Unknown task: {task}")  # add other task heads as needed

# Define the RL policy network
class RLPolicyNetwork(nn.Module):
    def __init__(self, input_dim=384, hidden_dim=128, num_actions=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.network(x)
        policy = torch.softmax(self.policy_head(features), dim=-1)
        value = self.value_head(features)
        return policy, value

# Path to the downloaded package
model_dir = "path/to/DomainEmbedder-v2.6"

# 1. Load the base model checkpoint
base_model = FireDevourerEmbedder()
checkpoint = torch.load(f"{model_dir}/FireDevourerEmbedder-RL-v3.6.pt", map_location=device)
base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
base_model.to(device)
base_model.eval()

# 2. Load the RL policy
rl_policy = RLPolicyNetwork()
rl_checkpoint = torch.load(f"{model_dir}/rl_policy.pt", map_location=device)
rl_policy.load_state_dict(rl_checkpoint['policy_state_dict'])
rl_policy.to(device)
rl_policy.eval()

# 3. Load a LoRA adapter (example: medical)
lora_model = PeftModel.from_pretrained(
    base_model.encoder,
    f"{model_dir}/medical_lora"
)
```

### Computing Embeddings with Domain Selection
```python
def get_domain_embedding(text, base_model, rl_policy, tokenizer, device):
    """Get a domain-aware embedding for the input text."""
    # Tokenize
    inputs = tokenizer(text, return_tensors='pt', padding=True,
                       truncation=True, max_length=512).to(device)

    with torch.no_grad():
        # Get the base embedding
        base_emb = base_model(inputs['input_ids'], inputs['attention_mask'], task='encode')

        # Get the domain selection from the RL policy
        policy_probs, _ = rl_policy(base_emb)

    domain_idx = torch.argmax(policy_probs, dim=-1).item()

    domains = ['medical', 'legal', 'code', 'finance', 'scientific']
    selected_domain = domains[domain_idx]
    confidence = policy_probs[0, domain_idx].item()

    return {
        'embedding': base_emb,
        'domain': selected_domain,
        'confidence': confidence,
        'all_probs': policy_probs[0].cpu().numpy()
    }

# Example usage
result = get_domain_embedding(
    "What are the symptoms of diabetes?",
    base_model, rl_policy, tokenizer, device
)
print(f"Domain: {result['domain']} (confidence: {result['confidence']:.2%})")
```
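
The helper above returns the base embedding plus the selected domain; a full pipeline would then re-encode the text through the chosen adapter. A minimal sketch of that routing step, with plain functions standing in for the `PeftModel` adapters (all names here are illustrative stand-ins, not package APIs):

```python
import numpy as np

def route_and_encode(text, select_domain, adapters, base_encode):
    """Re-encode text with the adapter chosen by the routing policy."""
    domain = select_domain(text)                # e.g. the RL policy's argmax
    encode = adapters.get(domain, base_encode)  # fall back to the base model
    return domain, encode(text)

# Stand-ins: each "adapter" is a deterministic toy encoder
base_encode = lambda t: np.zeros(4)
adapters = {
    'medical': lambda t: np.full(4, 1.0),
    'finance': lambda t: np.full(4, 2.0),
}
select_domain = lambda t: 'medical' if 'symptom' in t else 'finance'

domain, emb = route_and_encode('What are the symptoms of diabetes?',
                               select_domain, adapters, base_encode)
print(domain)  # medical
```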

## Architecture

```
Input Text
     │
     ▼
┌──────────────────────────────────────────────┐
│  MiniLM-L6-v2 Encoder (FROZEN)               │
│  + Optional LoRA Adapter (domain-specific)   │
│  384-dimensional output                      │
└──────────────────────────────────────────────┘
     │
     ├──────────────────────────────┐
     │                              │
     ▼                              ▼
┌───────────────────┐     ┌────────────────────┐
│  Base Embedding   │     │  RL Policy Net     │
│  (384-dim)        │     │  (66K params)      │
└───────────────────┘     └────────────────────┘
                                    │
                                    ▼
                            Domain Selection
                          [Medical, Legal, Code,
                           Finance, Scientific]
                                    │
                                    ▼
                     Load corresponding LoRA adapter
                                    │
                                    ▼
                         Domain-Adapted Embedding
```

### Component Details

| Component | Specification |
|-----------|---------------|
| Base Encoder | MiniLM-L6-v2 (22M params) |
| Embedding Dim | 384 |
| LoRA Rank | 16 |
| LoRA Alpha | 32 |
| LoRA Targets | Query and value projections |
| LoRA Params | 147,456 per adapter (0.645% of base) |
| RL Policy | 66,566 params |
| Domains | Medical, Legal, Code, Finance, Scientific |
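
The 147,456 figure is consistent with rank-16 adapters on the query and value projections of all 6 MiniLM layers; a quick sanity check (layer and module counts are inferred from the table above, not read from the checkpoint):

```python
# LoRA adds two low-rank matrices per target module:
# A (d_model x rank, down-projection) and B (rank x d_model, up-projection)
layers = 6          # MiniLM-L6 transformer layers
targets = 2         # query and value projections per layer
d_model, rank = 384, 16
per_module = d_model * rank + rank * d_model
lora_params = layers * targets * per_module
print(lora_params)  # 147456, matching the "LoRA Params" row
```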

## Performance

### Base Model: Multi-Task Embedding Quality

The base FireDevourerEmbedder achieves a **0.71 average** across 5 distinct NLP tasks:

| Task | Dataset | Score | What It Measures |
|------|---------|-------|------------------|
| Question Similarity | QQP | 0.8636 | Intent matching |
| Paraphrase Detection | PAWS | 0.8459 | Adversarial robustness |
| Paraphrase Detection | MRPC | 0.7744 | News-domain paraphrase |
| NLI | MultiNLI | 0.7465 | Logical relationships |
| Semantic Similarity | STS-B | 0.3366 | Fine-grained similarity |
| **Average** | | **0.7134** | **Cross-task capability** |

**Philosophy**: Peak single-task scores are traded for cross-domain information density, making the embeddings more versatile for RAG and retrieval across diverse content.
### Domain Routing Accuracy

**Training Results (In-Distribution)**

| Metric | Value |
|--------|-------|
| Domain Accuracy | 92.5% |
| Average Reward | 1.527 |
| Training Steps | 5,000 |

**Stress-Test Benchmark (Semantically Similar Cross-Domain Phrases)**

The benchmark intentionally uses complex, semantically similar phrases across domains to test robustness:

| Metric | DomainEmbedder (RL+LoRA) | Base Model | Improvement |
|--------|--------------------------|------------|-------------|
| Domain Accuracy | 56.0% | 20.4% | **+35.6 pts** |
| Avg Confidence | 28.5% | 77.6% | Better calibrated |

### Per-Domain Breakdown

| Domain | DomainEmbedder | Base Model | Note |
|--------|----------------|------------|------|
| Finance | 78.0% | 0.0% | +78.0 pts |
| Medical | 73.0% | 0.0% | +73.0 pts |
| Legal | 53.0% | 15.0% | +38.0 pts |
| Scientific | 48.0% | 1.0% | +47.0 pts |
| Code | 28.0% | 86.0% | Base over-predicted code |

**Key Insight**: The base model predicted "code" for 86% of inputs, with high confidence. The RL+LoRA system corrects this bias, yielding a more balanced and better-calibrated domain distribution.

## Training Details

### Domain Training Data

| Domain | Samples | Sources |
|--------|---------|---------|
| Medical | 40,000 | MedQA-USMLE, MedQuAD, PubMedQA, Medical Meadow, ChatDoctor |
| Legal | 40,000 | EUR-LEX, CaseHold, ECTHR-A, ECTHR-B |
| Code | 40,000 | Code Alpaca, MBPP, Code Contests, Python Instructions |
| Finance | 40,000 | Finance Alpaca, FinGPT-FiQA, Financial QA |
| Scientific | 40,000 | arXiv, PubMed (87.3% real + 12.7% augmented) |
| **Total** | **200,000** | |

### LoRA Training Configuration

| Parameter | Value |
|-----------|-------|
| Epochs | 3 per domain |
| Batch Size | 32 |
| Learning Rate | 2e-4 |
| Loss | Contrastive (InfoNCE-style) |
| Trainable Params | 147,456 (0.645% of base) |
| Warmup Steps | 500 |
| Max Gradient Norm | 1.0 |
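
The model card does not ship the training loss itself; an InfoNCE-style contrastive loss of the kind named in the table might look like this (NumPy sketch with in-batch negatives; the 0.07 temperature is an assumption, not a documented value):

```python
import numpy as np

def info_nce(anchors, positives, temperature=0.07):
    """Contrastive loss: each anchor's positive is the same-index row;
    every other row in the batch serves as an in-batch negative."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    logits = a @ p.T / temperature                 # (batch, batch) similarities
    logits -= logits.max(axis=1, keepdims=True)    # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))            # NLL of the matched pairs

rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 384))
loss_matched = info_nce(emb, emb)                       # identical positives
loss_random = info_nce(emb, rng.normal(size=(8, 384)))  # unrelated positives
print(loss_matched < loss_random)  # True: aligned pairs give a lower loss
```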

### RL Training (Supervised A2C)

| Parameter | Value |
|-----------|-------|
| Algorithm | Actor-Critic (A2C) |
| Total Steps | 5,000 |
| Episodes per Step | 5 |
| Gamma (discount) | 0.99 |
| Entropy Coef | 0.1 (high exploration) |
| Value Coef | 0.5 |
| Correctness Bonus | +1.0 |
| Correctness Penalty | -0.5 |
| Baseline Decay | 0.99 |
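
A single A2C update with the coefficients above can be sketched as follows (NumPy, one routing step; the state values and policy here are illustrative numbers, not from training):

```python
import numpy as np

ENTROPY_COEF, VALUE_COEF = 0.1, 0.5  # coefficients from the table

def a2c_loss(log_prob, value, reward, probs):
    """One-step A2C objective: actor term plus weighted critic term,
    minus an entropy bonus that encourages exploration."""
    advantage = reward - value                       # critic as baseline
    policy_loss = -log_prob * advantage              # actor (policy gradient)
    value_loss = (reward - value) ** 2               # critic regression
    entropy = -np.sum(probs * np.log(probs + 1e-9))  # exploration bonus
    return policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

# Correct routing step: reward +1.0 (the table's correctness bonus)
probs = np.full(5, 0.2)  # uniform policy over the 5 domains
loss = a2c_loss(np.log(probs[0]), value=0.4, reward=1.0, probs=probs)
print(round(float(loss), 3))  # ≈ 0.985; the entropy of a uniform 5-way policy is ln 5
```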

### Curriculum Learning Phases

| Phase | Steps | Data | Accuracy |
|-------|-------|------|----------|
| 1 (Easy) | 0-1,500 | Clear domain examples (10K) | 68.8% → 87.5% |
| 2 (Moderate) | 1,500-3,500 | Easy + ambiguous (20K) | 87.5% → 89.3% |
| 3 (Hard) | 3,500-5,000 | All data incl. hybrid (28K) | 89.3% → 92.5% |
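
The phase boundaries above map directly to a step-indexed schedule; a trivial sketch (the phase names are the table's own, the function is illustrative):

```python
def curriculum_phase(step):
    """Map a training step to its curriculum phase, per the table above."""
    if step < 1_500:
        return 1, 'easy'      # clear single-domain examples (10K)
    if step < 3_500:
        return 2, 'moderate'  # easy + ambiguous examples (20K)
    return 3, 'hard'          # all data, including hybrid text (28K)

print(curriculum_phase(0), curriculum_phase(2_000), curriculum_phase(4_999))
```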

### Training Progress

| Version | Step | Accuracy | Reward |
|---------|------|----------|--------|
| v2.1 | 500 | 68.8% | 1.100 |
| v2.2 | 1,000 | 80.1% | 1.336 |
| v2.3 | 1,500 | 87.5% | 1.454 |
| v2.4 | 2,000 | 88.9% | 1.480 |
| v2.5 | 3,000 | 89.3% | 1.507 |
| **v2.6** | **4,000** | **92.5%** | **1.527** |

## Package Contents

```
DomainEmbedder-v2.6/
├── FireDevourerEmbedder-RL-v3.6.pt   # Base model checkpoint (86.7 MB)
├── rl_policy.pt                      # Trained RL policy (0.27 MB)
├── metadata.json                     # Training metadata
├── README.md                         # This file
├── medical_lora/                     # Medical domain adapter (0.6 MB)
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── legal_lora/                       # Legal domain adapter (0.6 MB)
├── code_lora/                        # Code domain adapter (0.6 MB)
├── finance_lora/                     # Finance domain adapter (0.6 MB)
└── scientific_lora/                  # Scientific domain adapter (0.6 MB)
```

**Total Size**: ~90 MB (self-contained)

## Intended Use

### Best Use Cases

- **RAG Systems**: Domain-aware retrieval for multi-domain knowledge bases
- **Cross-Domain Search**: Finding similar content across the medical, legal, code, finance, and scientific domains
- **Document Classification**: Automatic domain routing for document-processing pipelines
- **Semantic Similarity**: Information-dense embeddings for precise matching
- **Multi-Domain Chatbots**: Context-appropriate responses based on the detected domain

### Limitations

- **English Only**: Trained exclusively on English data
- **Max Length**: 512-token maximum input length
- **Domain Coverage**: 5 domains only (Medical, Legal, Code, Finance, Scientific)
- **Stress-Test Accuracy**: 56% on semantically similar cross-domain queries
- **STS-B Trade-off**: Lower fine-grained similarity (0.34) in exchange for broader task coverage

## Citation

```bibtex
@misc{domainembedder2025,
  author = {Asad, Zain},
  title = {DomainEmbedder: Domain-Adaptive Embeddings with Dual RL and LoRA},
  year = {2025},
  publisher = {Hugging Face},
  note = {Multi-task base embedder with RL-based task weighting and domain-specific LoRA adapters trained with curriculum learning}
}
```

## Author

**Zain Asad**

## License

MIT License