Custom LLM
Some Conversational AI Engine use cases call for a custom large language model (Custom LLM). This document explains how to integrate a custom LLM into Agora's Conversational AI Engine.
Understand the tech
Agora's Conversational AI Engine interacts with LLM services using the OpenAI API protocol. To integrate a custom LLM, you need to provide an HTTP service compatible with the OpenAI API, capable of handling requests and responses in the OpenAI API format.
This approach enables you to implement additional custom functionalities, including but not limited to:
- Retrieval-Augmented Generation (RAG): Allows the model to retrieve information from a specific knowledge base.
- Multimodal Capabilities: Enables the model to generate output in both text and audio formats.
- Tool Invocation: Allows the model to call external tools.
- Function Calling: Enables the model to return structured data in the form of function calls (see the sketch after this list).
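Tool invocation and function calling both build on the OpenAI tools schema. As a sketch, a tool definition passed in the `tools` field of a request might look like the following; the `get_weather` function is purely illustrative:

```json
{
  "type": "function",
  "function": {
    "name": "get_weather",
    "description": "Get the current weather for a given city",
    "parameters": {
      "type": "object",
      "properties": {
        "city": { "type": "string", "description": "City name" }
      },
      "required": ["city"]
    }
  }
}
```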
Prerequisites
Before you begin, ensure that you have:
- Implemented the basic logic for interacting with a Conversational AI agent by following the REST Quickstart.
- Set up access to a custom LLM service.
- Prepared a vector database or retrieval system if using Retrieval-Augmented Generation (RAG).
Implementation
Take the following steps to integrate your custom LLM into Agora's Conversational AI Engine.
Create an OpenAI API-compatible service
To integrate successfully with Agora's Conversational AI Engine, your custom LLM service must provide an interface compatible with the OpenAI Chat Completions API. The key requirements include:
- Endpoint: A request-handling endpoint, such as `https://your-custom-llm-service/chat/completions`.
- Request format: Must accept request parameters adhering to the OpenAI API protocol.
- Response format: Should return OpenAI API-compatible responses and support the Server-Sent Events (SSE) standard for streaming (see the sample stream below).
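For reference, a streaming response is a sequence of `data:` lines, each carrying an OpenAI-style `chat.completion.chunk` object, terminated by `data: [DONE]`. A representative sketch, with illustrative field values:

```
data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}

data: [DONE]
```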
The following example demonstrates how to implement an OpenAI API-compliant interface:
- Python
- Go
```python
import asyncio
import json
import logging
import os
import traceback
from typing import Dict, List, Optional, Union

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from openai import AsyncOpenAI
from pydantic import BaseModel, HttpUrl

logger = logging.getLogger(__name__)
app = FastAPI()


class TextContent(BaseModel):
    type: str = "text"
    text: str


class ImageContent(BaseModel):
    type: str = "image"
    image_url: HttpUrl


class AudioContent(BaseModel):
    type: str = "input_audio"
    input_audio: Dict[str, str]


class ToolFunction(BaseModel):
    name: str
    description: Optional[str]
    parameters: Optional[Dict]
    strict: bool = False


class Tool(BaseModel):
    type: str = "function"
    function: ToolFunction


class ToolChoice(BaseModel):
    type: str = "function"
    function: Optional[Dict]


class ResponseFormat(BaseModel):
    type: str = "json_schema"
    json_schema: Optional[Dict[str, str]]


class SystemMessage(BaseModel):
    role: str = "system"
    content: Union[str, List[str]]


class UserMessage(BaseModel):
    role: str = "user"
    content: Union[str, List[Union[TextContent, ImageContent, AudioContent]]]


class AssistantMessage(BaseModel):
    role: str = "assistant"
    content: Optional[Union[str, List[TextContent]]] = None
    audio: Optional[Dict[str, str]] = None
    tool_calls: Optional[List[Dict]] = None


class ToolMessage(BaseModel):
    role: str = "tool"
    content: Union[str, List[str]]
    tool_call_id: str


# Define the complete request format
class ChatCompletionRequest(BaseModel):
    context: Optional[Dict] = None  # Context information
    model: Optional[str] = None  # Model name being used
    messages: List[
        Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
    ]  # List of messages
    response_format: Optional[ResponseFormat] = None  # Response format
    modalities: List[str] = ["text"]  # Default modality is text
    audio: Optional[Dict[str, str]] = None  # Assistant's audio response
    tools: Optional[List[Tool]] = None  # List of tools
    tool_choice: Optional[Union[str, ToolChoice]] = "auto"  # Tool selection
    parallel_tool_calls: bool = True  # Whether to call tools in parallel
    stream: bool = True  # Default to streaming response
    stream_options: Optional[Dict] = None  # Streaming options


@app.post("/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    try:
        logger.info(f"Received request: {request.model_dump_json()}")
        if not request.stream:
            raise HTTPException(
                status_code=400, detail="chat completions require streaming"
            )
        client = AsyncOpenAI(api_key=os.getenv("YOUR_LLM_API_KEY"))
        response = await client.chat.completions.create(
            model=request.model,
            messages=request.messages,  # Directly use request messages
            tool_choice=(
                request.tool_choice if request.tools and request.tool_choice else None
            ),
            tools=request.tools if request.tools else None,
            modalities=request.modalities,
            audio=request.audio,
            response_format=request.response_format,
            stream=request.stream,
            stream_options=request.stream_options,
        )

        async def generate():
            try:
                async for chunk in response:
                    logger.debug(f"Received chunk: {chunk}")
                    yield f"data: {json.dumps(chunk.to_dict())}\n\n"
                yield "data: [DONE]\n\n"
            except asyncio.CancelledError:
                logger.info("Request was cancelled")
                raise

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        logger.info("Request was cancelled")
        raise HTTPException(status_code=499, detail="Request was cancelled")
    except Exception as e:
        traceback_str = "".join(traceback.format_tb(e.__traceback__))
        error_message = f"{str(e)}\n{traceback_str}"
        logger.error(error_message)
        raise HTTPException(status_code=500, detail=error_message)
```
```go
package main

import (
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"

	"github.com/gin-gonic/gin"
	openai "github.com/sashabaranov/go-openai"
)

type (
	AudioContent struct {
		InputAudio map[string]string `json:"input_audio"`
		Type       string            `json:"type"`
	}
	// Complete request format
	ChatCompletionRequest struct {
		// Assistant's audio reply
		Audio map[string]string `json:"audio,omitempty"`
		// Context information
		Context map[string]any `json:"context,omitempty"`
		// Message list
		Messages []Message `json:"messages"`
		// Default uses text modality
		Modalities []string `json:"modalities"`
		// Model name to use
		Model string `json:"model,omitempty"`
		// Whether to call tools in parallel
		ParallelToolCalls bool `json:"parallel_tool_calls"`
		// Response format
		ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
		// Whether to use streaming response
		Stream bool `json:"stream"`
		// Streaming options
		StreamOptions map[string]any `json:"stream_options,omitempty"`
		// Tool selection strategy, default value is "auto"
		ToolChoice any `json:"tool_choice,omitempty"`
		// Tool list
		Tools []Tool `json:"tools,omitempty"`
	}
	ImageContent struct {
		ImageURL string `json:"image_url"`
		Type     string `json:"type"`
	}
	Message struct {
		Audio      map[string]string `json:"audio,omitempty"`
		Content    any               `json:"content"`
		Role       string            `json:"role"`
		ToolCallID string            `json:"tool_call_id,omitempty"`
		ToolCalls  []map[string]any  `json:"tool_calls,omitempty"`
	}
	ResponseFormat struct {
		JSONSchema map[string]string `json:"json_schema,omitempty"`
		Type       string            `json:"type"`
	}
	TextContent struct {
		Text string `json:"text"`
		Type string `json:"type"`
	}
	Tool struct {
		Function ToolFunction `json:"function"`
		Type     string       `json:"type"`
	}
	ToolChoice struct {
		Function map[string]any `json:"function,omitempty"`
		Type     string         `json:"type"`
	}
	ToolFunction struct {
		Description string         `json:"description,omitempty"`
		Name        string         `json:"name"`
		Parameters  map[string]any `json:"parameters,omitempty"`
		Strict      bool           `json:"strict"`
	}
)

var waitingMessages = []string{
	"Just a moment, I'm thinking...",
	"Let me think about that for a second...",
	"Good question, let me find out...",
}

// Chat Completion Server
type Server struct {
	client *openai.Client
	logger *slog.Logger
}

// Create a new server instance
func NewServer(apiKey string) *Server {
	return &Server{
		client: openai.NewClient(apiKey),
		logger: slog.New(slog.NewJSONHandler(os.Stdout, nil)),
	}
}

// Handle Chat Completion endpoint
func (s *Server) handleChatCompletion(c *gin.Context) {
	var request ChatCompletionRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		s.sendError(c, http.StatusBadRequest, err)
		return
	}
	if !request.Stream {
		s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
		return
	}

	// Set SSE headers
	c.Header("Content-Type", "text/event-stream")

	responseChan := make(chan any, 100)
	errorChan := make(chan error, 1)

	go func() {
		messages := make([]openai.ChatCompletionMessage, len(request.Messages))
		for i, msg := range request.Messages {
			if strContent, ok := msg.Content.(string); ok {
				messages[i] = openai.ChatCompletionMessage{
					Role:    msg.Role,
					Content: strContent,
				}
			}
		}

		req := openai.ChatCompletionRequest{
			Model:    request.Model,
			Messages: messages,
			Stream:   true,
		}
		if len(request.Tools) > 0 {
			tools := make([]openai.Tool, len(request.Tools))
			for i, tool := range request.Tools {
				tools[i] = openai.Tool{
					Type: openai.ToolTypeFunction,
					Function: &openai.FunctionDefinition{
						Name:        tool.Function.Name,
						Description: tool.Function.Description,
						Parameters:  tool.Function.Parameters,
					},
				}
			}
			req.Tools = tools
		}

		stream, err := s.client.CreateChatCompletionStream(c.Request.Context(), req)
		if err != nil {
			errorChan <- err
			return
		}
		defer stream.Close()

		for {
			response, err := stream.Recv()
			if err == io.EOF {
				break
			}
			if err != nil {
				errorChan <- err
				return
			}
			responseChan <- response
		}
		close(responseChan)
	}()

	for {
		select {
		case chunk, ok := <-responseChan:
			if !ok {
				c.SSEvent("data", "[DONE]")
				return
			}
			data, _ := json.Marshal(chunk)
			c.SSEvent("data", string(data))
		case err := <-errorChan:
			s.logger.Error("Error in chat completion stream", "err", err)
			s.sendError(c, http.StatusInternalServerError, err)
			return
		}
	}
}

// Send error response to client
func (s *Server) sendError(c *gin.Context, status int, err error) {
	c.JSON(status, gin.H{"detail": err.Error()})
}

func main() {
	// Initialize server
	server := NewServer(os.Getenv("YOUR_LLM_API_KEY"))
	// Initialize Gin router
	r := gin.Default()
	// Set routes
	r.POST("/chat/completions", server.handleChatCompletion)
	// Start server
	r.Run(":8000")
}
```
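Before wiring the service into an agent, you can sanity-check it by pointing the standard OpenAI client at it. The following is a minimal sketch, assuming the service runs locally on port 8000 (a hypothetical address) and forwards the `model` value to your upstream LLM:

```python
from openai import OpenAI

# Point the standard OpenAI SDK at the custom service; the SDK appends
# /chat/completions to the base_url automatically.
client = OpenAI(base_url="http://localhost:8000", api_key="placeholder-key")

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # Any model your service supports; illustrative only
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    # Guard against chunks with no choices or empty deltas.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```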
When calling the POST method to Start a conversational AI agent, use the LLM configuration to point your agent to the custom service:
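A minimal sketch of the `llm` portion of the request body follows; `url` and `api_key` are the fields this page relies on, and the full schema is defined in the Start a conversational AI agent API reference:

```json
{
  "llm": {
    "url": "https://your-custom-llm-service/chat/completions",
    "api_key": "<your_api_key>"
  }
}
```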
If accessing your custom LLM service requires authentication, provide the credentials in the `api_key` field.
Advanced features
To integrate advanced features such as Retrieval-Augmented Generation and generating outputs in multimodal forms, refer to the following sections.
Retrieval-Augmented Generation
To improve the accuracy and relevance of the agent's responses, use the Retrieval-Augmented Generation (RAG) feature. This feature allows your custom LLM to retrieve information from a specific knowledge base and use the retrieved results as context for generating responses.
The following example simulates the process of retrieving and returning content from a knowledge base, and creates a `/rag/chat/completions` endpoint that incorporates the RAG retrieval results when generating responses with the LLM.
- Python
- Go
```python
import random


async def perform_rag_retrieval(messages: List) -> str:
    """
    Retrieve relevant content from the knowledge base using the RAG model.

    Args:
        messages: The original message list.

    Returns:
        str: The retrieved text content.
    """
    # TODO: Implement the actual RAG retrieval logic.
    # You can choose the first or last message from the message list as the query,
    # then send the query to the RAG model to retrieve relevant content.
    # Return the retrieval result.
    return "This is relevant content retrieved from the knowledge base."


def refact_messages(context: str, messages: List) -> List:
    """
    Modify the message list by adding the retrieved context to the original messages.

    Args:
        context: The retrieved context.
        messages: The original message list.

    Returns:
        List: The modified message list.
    """
    # TODO: Implement the actual message modification logic.
    # This should add the retrieved context to the original message list.
    return messages


# Random waiting messages.
waiting_messages = [
    "Just a moment, I'm thinking...",
    "Let me think about that for a second...",
    "Good question, let me find out...",
]


@app.post("/rag/chat/completions")
async def create_rag_chat_completion(request: ChatCompletionRequest):
    try:
        logger.info(f"Received RAG request: {request.model_dump_json()}")
        if not request.stream:
            raise HTTPException(
                status_code=400, detail="chat completions require streaming"
            )

        async def generate():
            # First, send a "please wait" prompt.
            waiting_message = {
                "id": "waiting_msg",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                            "content": random.choice(waiting_messages),
                        },
                        "finish_reason": None,
                    }
                ],
            }
            yield f"data: {json.dumps(waiting_message)}\n\n"

            # Perform RAG retrieval.
            retrieved_context = await perform_rag_retrieval(request.messages)
            # Modify messages.
            refacted_messages = refact_messages(retrieved_context, request.messages)

            # Request LLM completion.
            client = AsyncOpenAI(api_key=os.getenv("YOUR_LLM_API_KEY"))
            response = await client.chat.completions.create(
                model=request.model,
                messages=refacted_messages,
                tool_choice=(
                    request.tool_choice
                    if request.tools and request.tool_choice
                    else None
                ),
                tools=request.tools if request.tools else None,
                modalities=request.modalities,
                audio=request.audio,
                response_format=request.response_format,
                stream=True,  # Force streaming.
                stream_options=request.stream_options,
            )
            try:
                async for chunk in response:
                    logger.debug(f"Received RAG chunk: {chunk}")
                    yield f"data: {json.dumps(chunk.to_dict())}\n\n"
                yield "data: [DONE]\n\n"
            except asyncio.CancelledError:
                logger.info("RAG stream was cancelled")
                raise

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        logger.info("RAG request was cancelled")
        raise HTTPException(status_code=499, detail="Request was cancelled")
    except Exception as e:
        traceback_str = "".join(traceback.format_tb(e.__traceback__))
        error_message = f"{str(e)}\n{traceback_str}"
        logger.error(error_message)
        raise HTTPException(status_code=500, detail=error_message)
```
```go
// Requires "math/rand" in addition to the imports shown in the previous example.

// Handle RAG Chat Completion endpoint
func (s *Server) handleRAGChatCompletion(c *gin.Context) {
	var request ChatCompletionRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		s.sendError(c, http.StatusBadRequest, err)
		return
	}
	if !request.Stream {
		s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
		return
	}

	// Set SSE headers
	c.Header("Content-Type", "text/event-stream")

	// First send a "please wait" prompt
	waitingMsg := map[string]any{
		"id": "waiting_msg",
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": waitingMessages[rand.Intn(len(waitingMessages))],
				},
				"finish_reason": nil,
			},
		},
	}
	data, _ := json.Marshal(waitingMsg)
	c.SSEvent("data", string(data))

	// Perform RAG retrieval
	retrievedContext, err := s.performRAGRetrieval(request.Messages)
	if err != nil {
		s.logger.Error("Failed to perform RAG retrieval", "err", err)
		s.sendError(c, http.StatusInternalServerError, err)
		return
	}

	// Adjust messages
	refactedMessages := s.refactMessages(retrievedContext, request.Messages)

	// Convert messages to OpenAI format
	messages := make([]openai.ChatCompletionMessage, len(refactedMessages))
	for i, msg := range refactedMessages {
		if strContent, ok := msg.Content.(string); ok {
			messages[i] = openai.ChatCompletionMessage{
				Role:    msg.Role,
				Content: strContent,
			}
		}
	}

	req := openai.ChatCompletionRequest{
		Model:    request.Model,
		Messages: messages,
		Stream:   true,
	}

	stream, err := s.client.CreateChatCompletionStream(c.Request.Context(), req)
	if err != nil {
		s.sendError(c, http.StatusInternalServerError, err)
		return
	}
	defer stream.Close()

	for {
		response, err := stream.Recv()
		if err == io.EOF {
			break
		}
		if err != nil {
			s.sendError(c, http.StatusInternalServerError, err)
			return
		}
		data, _ := json.Marshal(response)
		c.SSEvent("data", string(data))
	}
	c.SSEvent("data", "[DONE]")
}

// performRAGRetrieval uses the RAG model to retrieve relevant content from the
// knowledge base based on the message list.
//
// messages: Contains the original message list.
//
// Returns the retrieved text content and any errors that occurred during retrieval.
func (s *Server) performRAGRetrieval(messages []Message) (string, error) {
	// TODO: Implement actual RAG retrieval logic.
	// You may need to select the first or last message from the message list as
	// a query, then send the query to the RAG model to retrieve relevant content.
	return "This is relevant content retrieved from the knowledge base.", nil
}

// refactMessages adjusts the message list, adding the retrieved context to the
// original message list.
//
// context: Contains the retrieved context.
// messages: Contains the original message list.
//
// Returns the adjusted message list.
func (s *Server) refactMessages(context string, messages []Message) []Message {
	// TODO: Implement actual message adjustment logic.
	// This should add the retrieved context to the original message list.
	// For now, return the original messages unchanged.
	return messages
}
```
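The `refact_messages`/`refactMessages` helpers above are intentionally left as TODOs. One possible implementation, shown here as a sketch in Python assuming messages are plain dicts, prepends the retrieved context as a system message so the model treats it as grounding knowledge:

```python
def refact_messages(context: str, messages: list) -> list:
    # Prepend the retrieved context as a system message so the LLM
    # answers with the knowledge-base content in scope.
    context_message = {
        "role": "system",
        "content": f"Use the following retrieved context when answering:\n{context}",
    }
    return [context_message] + list(messages)
```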
When calling the POST method to Start a conversational AI agent, simply point the LLM URL to your RAG interface:
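A sketch of the corresponding `llm` configuration, using the same fields as before but pointing to the RAG endpoint:

```json
{
  "llm": {
    "url": "https://your-custom-llm-service/rag/chat/completions",
    "api_key": "<your_api_key>"
  }
}
```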
If accessing your custom LLM service requires authentication, provide the credentials in the `api_key` field.
Multimodal capabilities
Conversational AI Engine supports LLMs that generate output in multiple modalities, including text and audio. You can create dedicated multimodal endpoints to meet custom requirements.
The following example demonstrates how to read text and audio files and send them to an LLM to generate an audio response.
- Python
- Go
```python
import base64
import uuid

import aiofiles


async def read_text_file(file_path: str) -> str:
    """
    Read a text file and return its content.

    Args:
        file_path: Path to the text file.

    Returns:
        str: Content of the text file.
    """
    async with aiofiles.open(file_path, "r") as file:
        content = await file.read()
    return content


async def read_pcm_file(
    file_path: str, sample_rate: int, duration_ms: int
) -> List[bytes]:
    """
    Read a PCM file and return a list of audio chunks.

    Args:
        file_path: Path to the PCM file.
        sample_rate: Sampling rate of the audio.
        duration_ms: Duration of each audio chunk in milliseconds.

    Returns:
        List: List of audio chunks.
    """
    async with aiofiles.open(file_path, "rb") as file:
        content = await file.read()
    chunk_size = int(sample_rate * 2 * (duration_ms / 1000))
    return [content[i : i + chunk_size] for i in range(0, len(content), chunk_size)]


@app.post("/audio/chat/completions")
async def create_audio_chat_completion(request: ChatCompletionRequest):
    try:
        logger.info(f"Received audio request: {request.model_dump_json()}")
        if not request.stream:
            raise HTTPException(
                status_code=400, detail="chat completions require streaming"
            )

        # Example usage, reading text and audio files.
        # Replace with your actual logic.
        text_file_path = "./file.txt"
        pcm_file_path = "./file.pcm"
        sample_rate = 16000  # Example sampling rate
        duration_ms = 40  # 40 ms audio chunks

        text_content = await read_text_file(text_file_path)
        audio_chunks = await read_pcm_file(pcm_file_path, sample_rate, duration_ms)

        async def generate():
            try:
                # Send text content.
                audio_id = uuid.uuid4().hex
                text_message = {
                    "id": uuid.uuid4().hex,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "audio": {
                                    "id": audio_id,
                                    "transcript": text_content,
                                },
                            },
                            "finish_reason": None,
                        }
                    ],
                }
                yield f"data: {json.dumps(text_message)}\n\n"

                # Send audio chunks.
                for chunk in audio_chunks:
                    audio_message = {
                        "id": uuid.uuid4().hex,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {
                                    "audio": {
                                        "id": audio_id,
                                        "data": base64.b64encode(chunk).decode("utf-8"),
                                    },
                                },
                                "finish_reason": None,
                            }
                        ],
                    }
                    yield f"data: {json.dumps(audio_message)}\n\n"
                yield "data: [DONE]\n\n"
            except asyncio.CancelledError:
                logger.info("Audio stream was cancelled")
                raise

        return StreamingResponse(generate(), media_type="text/event-stream")
    except asyncio.CancelledError:
        logger.info("Audio request was cancelled")
        raise HTTPException(status_code=499, detail="Request was cancelled")
    except Exception as e:
        traceback_str = "".join(traceback.format_tb(e.__traceback__))
        error_message = f"{str(e)}\n{traceback_str}"
        logger.error(error_message)
        raise HTTPException(status_code=500, detail=error_message)
```
```go
// Requires "encoding/base64" and "github.com/google/uuid" in addition to the
// imports shown in the first example.

// Handle Audio Chat Completion endpoint
func (s *Server) handleAudioChatCompletion(c *gin.Context) {
	var request ChatCompletionRequest
	if err := c.ShouldBindJSON(&request); err != nil {
		s.sendError(c, http.StatusBadRequest, err)
		return
	}
	if !request.Stream {
		s.sendError(c, http.StatusBadRequest, fmt.Errorf("chat completions require streaming"))
		return
	}

	// Set SSE headers
	c.Header("Content-Type", "text/event-stream")

	// Read text and audio files
	textContent, err := s.readTextFile("./file.txt")
	if err != nil {
		s.logger.Error("Failed to read text file", "err", err)
		s.sendError(c, http.StatusInternalServerError, err)
		return
	}

	sampleRate := 16000 // Example sample rate
	durationMs := 40    // 40ms chunks
	audioChunks, err := s.readPCMFile("./file.pcm", sampleRate, durationMs)
	if err != nil {
		s.logger.Error("Failed to read PCM file", "err", err)
		s.sendError(c, http.StatusInternalServerError, err)
		return
	}

	// Send text content
	audioID := uuid.New().String()
	textMessage := map[string]any{
		"id": uuid.New().String(),
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"audio": map[string]any{
						"id":         audioID,
						"transcript": textContent,
					},
				},
				"finish_reason": nil,
			},
		},
	}
	data, _ := json.Marshal(textMessage)
	c.SSEvent("data", string(data))

	// Send audio chunks
	for _, chunk := range audioChunks {
		audioMessage := map[string]any{
			"id": uuid.New().String(),
			"choices": []map[string]any{
				{
					"index": 0,
					"delta": map[string]any{
						"audio": map[string]any{
							"id":   audioID,
							"data": base64.StdEncoding.EncodeToString(chunk),
						},
					},
					"finish_reason": nil,
				},
			},
		}
		data, _ := json.Marshal(audioMessage)
		c.SSEvent("data", string(data))
	}
	c.SSEvent("data", "[DONE]")
}

// readPCMFile reads a PCM file and returns audio chunks.
//
// filePath: Specifies the path to the PCM file.
// sampleRate: Specifies the sampling rate of the audio.
// durationMs: Specifies the duration of each audio chunk in milliseconds.
//
// Returns a list of audio chunks and any errors that occurred during reading.
func (s *Server) readPCMFile(filePath string, sampleRate int, durationMs int) ([][]byte, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read PCM file: %w", err)
	}
	chunkSize := int(float64(sampleRate) * 2 * float64(durationMs) / 1000.0)
	if chunkSize == 0 {
		return nil, fmt.Errorf("invalid chunk size: sample rate %d, duration %dms", sampleRate, durationMs)
	}
	chunks := make([][]byte, 0, len(data)/chunkSize+1)
	for i := 0; i < len(data); i += chunkSize {
		end := i + chunkSize
		if end > len(data) {
			end = len(data)
		}
		chunks = append(chunks, data[i:end])
	}
	return chunks, nil
}

// readTextFile reads a text file and returns its content.
//
// filePath: Specifies the path to the text file.
//
// Returns the content of the text file and any errors that occurred during reading.
func (s *Server) readTextFile(filePath string) (string, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to read text file: %w", err)
	}
	return string(data), nil
}
```
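In these examples, the chunk size follows directly from the PCM parameters: at a 16 kHz sample rate with 16-bit (2-byte) mono samples, a 40 ms chunk is 16000 × 2 × 0.04 = 1280 bytes, which is then Base64-encoded into the `data` field of each audio delta.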
When calling the POST method to Start a conversational AI agent, use the following configuration to specify the `output_modalities`:
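A sketch of the corresponding configuration; the exact placement of `output_modalities` within the `llm` object should be confirmed against the API reference:

```json
{
  "llm": {
    "url": "https://your-custom-llm-service/audio/chat/completions",
    "api_key": "<your_api_key>",
    "output_modalities": ["audio"]
  }
}
```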
Reference
This section contains additional information that completes the content on this page, or points you to documentation that covers other aspects of this product.
Sample project
Agora provides open source sample projects on GitHub for your reference. Download the project or view the source code for a more complete example.
Interface standards
Custom LLM services must be compatible with the OpenAI Chat Completions API interface standard:
- Request format: Contains parameters such as the model, messages, and tool call configuration.
- Response format: Contains the model-generated response, metadata, and other information.
- Streaming response: Complies with the Server-Sent Events (SSE) specification.
For detailed interface standards, refer to the OpenAI Chat Completions API documentation.