| import os |
| import openai |
| from typing import TYPE_CHECKING, Literal, Optional |
| from langchain_core.language_models.chat_models import BaseChatModel |
|
|
| if TYPE_CHECKING: from histopath.config import HistoPathConfig |
|
|
| SourceType = Literal["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "HuggingFace", "Custom"] |
| ALLOWED_SOURCES: set[str] = set(SourceType.__args__) |
|
|
|
|
| def get_llm( |
| model: str | None = None, |
| temperature: float | None = None, |
| stop_sequences: list[str] | None = None, |
| source: SourceType | None = None, |
| base_url: str | None = None, |
| api_key: str | None = None, |
| config: Optional["HistoPathConfig"] = None, |
| ) -> BaseChatModel: |
| """ |
| Get a language model instance based on the specified model name and source. |
| This function supports models from OpenAI, Azure OpenAI, Anthropic, Ollama, Gemini, Bedrock, and custom model serving. |
| Args: |
| model (str): The model name to use |
| temperature (float): Temperature setting for generation |
| stop_sequences (list): Sequences that will stop generation |
| source (str): Source provider: "OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", or "Custom" |
| If None, will attempt to auto-detect from model name |
| base_url (str): The base URL for custom model serving (e.g., "http://localhost:8000/v1"), default is None |
| api_key (str): The API key for the custom llm |
| config (BiomniConfig): Optional configuration object. If provided, unspecified parameters will use config values |
| """ |
| |
| if config is not None: |
| if model is None: |
| model = config.llm_model |
| if temperature is None: |
| temperature = config.temperature |
| if source is None: |
| source = config.source |
| if base_url is None: |
| base_url = config.base_url |
| if api_key is None: |
| api_key = config.api_key or "EMPTY" |
|
|
| |
| if model is None: |
| model = "claude-3-5-sonnet-20241022" |
| if temperature is None: |
| temperature = 0.7 |
| if api_key is None: |
| api_key = "EMPTY" |
| |
| if source is None: |
| env_source = os.getenv("LLM_SOURCE") |
| if env_source in ALLOWED_SOURCES: |
| source = env_source |
| else: |
| if model[:7] == "claude-": |
| source = "Anthropic" |
| elif model[:7] == "gpt-oss": |
| source = "Ollama" |
| elif model[:4] == "gpt-": |
| source = "OpenAI" |
| elif model.startswith("azure-"): |
| source = "AzureOpenAI" |
| elif model[:7] == "gemini-": |
| source = "Gemini" |
| elif "groq" in model.lower(): |
| source = "Groq" |
| elif base_url is not None: |
| source = "Custom" |
| elif "/" in model or any( |
| name in model.lower() |
| for name in [ |
| "llama", |
| "mistral", |
| "qwen", |
| "gemma", |
| "phi", |
| "dolphin", |
| "orca", |
| "vicuna", |
| "deepseek", |
| ] |
| ): |
| source = "Ollama" |
| elif model.startswith( |
| ("anthropic.claude-", "amazon.titan-", "meta.llama-", "mistral.", "cohere.", "ai21.", "us.") |
| ): |
| source = "Bedrock" |
| else: |
| raise ValueError("Unable to determine model source. Please specify 'source' parameter.") |
|
|
| |
| if source == "OpenAI": |
| try: |
| from langchain_openai import ChatOpenAI |
| except ImportError: |
| raise ImportError( |
| "langchain-openai package is required for OpenAI models. Install with: pip install langchain-openai" |
| ) |
| return ChatOpenAI(model=model, temperature=temperature, stop_sequences=stop_sequences) |
|
|
| elif source == "AzureOpenAI": |
| try: |
| from langchain_openai import AzureChatOpenAI |
| except ImportError: |
| raise ImportError( |
| "langchain-openai package is required for Azure OpenAI models. Install with: pip install langchain-openai" |
| ) |
| API_VERSION = "2024-12-01-preview" |
| model = model.replace("azure-", "") |
| return AzureChatOpenAI( |
| openai_api_key=os.getenv("OPENAI_API_KEY"), |
| azure_endpoint=os.getenv("OPENAI_ENDPOINT"), |
| azure_deployment=model, |
| openai_api_version=API_VERSION, |
| temperature=temperature, |
| ) |
|
|
| elif source == "Anthropic": |
| try: |
| from langchain_anthropic import ChatAnthropic |
| except ImportError: |
| raise ImportError( |
| "langchain-anthropic package is required for Anthropic models. Install with: pip install langchain-anthropic" |
| ) |
| return ChatAnthropic( |
| model=model, |
| temperature=temperature, |
| max_tokens=8192, |
| stop_sequences=stop_sequences, |
| ) |
|
|
| elif source == "Gemini": |
| |
| |
| |
| |
| |
| |
| try: |
| from langchain_openai import ChatOpenAI |
| except ImportError: |
| raise ImportError( |
| "langchain-openai package is required for Gemini models. Install with: pip install langchain-openai" |
| ) |
| return ChatOpenAI( |
| model=model, |
| temperature=temperature, |
| api_key=os.getenv("GEMINI_API_KEY"), |
| base_url="https://generativelanguage.googleapis.com/v1beta/openai/", |
| stop_sequences=stop_sequences, |
| ) |
|
|
| elif source == "Groq": |
| try: |
| from langchain_openai import ChatOpenAI |
| except ImportError: |
| raise ImportError( |
| "langchain-openai package is required for Groq models. Install with: pip install langchain-openai" |
| ) |
| return ChatOpenAI( |
| model=model, |
| temperature=temperature, |
| api_key=os.getenv("GROQ_API_KEY"), |
| base_url="https://api.groq.com/openai/v1", |
| stop_sequences=stop_sequences, |
| ) |
|
|
| elif source == "Ollama": |
| try: |
| from langchain_ollama import ChatOllama |
| except ImportError: |
| raise ImportError( |
| "langchain-ollama package is required for Ollama models. Install with: pip install langchain-ollama" |
| ) |
| return ChatOllama( |
| model=model, |
| temperature=temperature, |
| ) |
| |
| elif source == "Bedrock": |
| try: |
| from langchain_aws import ChatBedrock |
| except ImportError: |
| raise ImportError( |
| "langchain-aws package is required for Bedrock models. Install with: pip install langchain-aws" |
| ) |
| return ChatBedrock( |
| model=model, |
| temperature=temperature, |
| stop_sequences=stop_sequences, |
| region_name=os.getenv("AWS_REGION", "us-east-1"), |
| ) |
| elif source == "HuggingFace": |
| try: |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace |
| except ImportError: |
| raise ImportError( |
| "langchain-huggingface package is required for HuggingFace models. Install with: pip install langchain-huggingface" |
| ) |
| llm = HuggingFaceEndpoint( |
| repo_id="openai/gpt-oss-120b", |
| temperature=temperature, |
| stop_sequences=stop_sequences, |
| huggingfacehub_api_token=os.getenv("HUGGINGFACE_API_KEY") |
| ) |
| return ChatHuggingFace(llm=llm) |
| |
| elif source == "Custom": |
| try: |
| from langchain_openai import ChatOpenAI |
| except ImportError: |
| raise ImportError( |
| "langchain-openai package is required for custom models. Install with: pip install langchain-openai" |
| ) |
| |
| assert base_url is not None, "base_url must be provided for customly served LLMs" |
| llm = ChatOpenAI( |
| model=model, |
| temperature=temperature, |
| max_tokens=8192, |
| stop_sequences=stop_sequences, |
| base_url=base_url, |
| api_key=api_key, |
| ) |
| return llm |
|
|
| else: |
| raise ValueError( |
| f"Invalid source: {source}. Valid options are 'OpenAI', 'AzureOpenAI', 'Anthropic', 'Gemini', 'Groq', 'Bedrock', or 'Ollama'" |
| ) |