package openai import ( "bufio" "bytes" "context" "encoding/json" "fmt" "io" "net/http" "strings" "git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) func convertTools(tools []model.Tool) []Tool { openaiTools := make([]Tool, len(tools)) for i, tool := range tools { openaiTools[i].Type = "function" params := make(map[string]ToolParameter) var required []string for _, param := range tool.Parameters { params[param.Name] = ToolParameter{ Type: param.Type, Description: param.Description, Enum: param.Enum, } if param.Required { required = append(required, param.Name) } } openaiTools[i].Function = FunctionDefinition{ Name: tool.Name, Description: tool.Description, Parameters: ToolParameters{ Type: "object", Properties: params, Required: required, }, } } return openaiTools } func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall { converted := make([]ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].Type = "function" converted[i].ID = call.ID converted[i].Function.Name = call.Name json, _ := json.Marshal(call.Parameters) converted[i].Function.Arguments = string(json) } return converted } func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall { converted := make([]model.ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].ID = call.ID converted[i].Name = call.Function.Name json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters) } return converted } func createChatCompletionRequest( params model.RequestParameters, messages []model.Message, ) ChatCompletionRequest { requestMessages := make([]ChatCompletionMessage, 0, len(messages)) for _, m := range messages { switch m.Role { case "tool_call": message := ChatCompletionMessage{} message.Role = "assistant" message.Content = m.Content message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) requestMessages = append(requestMessages, message) case "tool_result": // expand tool_result messages' results into multiple openAI messages for _, result := range m.ToolResults { message := ChatCompletionMessage{} message.Role = "tool" message.Content = result.Result message.ToolCallID = result.ToolCallID requestMessages = append(requestMessages, message) } default: message := ChatCompletionMessage{} message.Role = string(m.Role) message.Content = m.Content requestMessages = append(requestMessages, message) } } request := ChatCompletionRequest{ Model: params.Model, MaxTokens: params.MaxTokens, Temperature: params.Temperature, Messages: requestMessages, N: 1, // limit responses to 1 "choice". we use choices[0] to reference it } if len(params.ToolBag) > 0 { request.Tools = convertTools(params.ToolBag) request.ToolChoice = "auto" } return request } func handleToolCalls( params model.RequestParameters, content string, toolCalls []ToolCall, callback api.ReplyCallback, messages []model.Message, ) ([]model.Message, error) { lastMessage := messages[len(messages)-1] continuation := false if lastMessage.Role.IsAssistant() { continuation = true } toolCall := model.Message{ Role: model.MessageRoleToolCall, Content: content, ToolCalls: convertToolCallToAPI(toolCalls), } toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) if err != nil { return nil, err } toolResult := model.Message{ Role: model.MessageRoleToolResult, ToolResults: toolResults, } if callback != nil { callback(toolCall) callback(toolResult) } if continuation { messages[len(messages)-1] = toolCall } else { messages = append(messages, toolCall) } messages = append(messages, toolResult) return messages, nil } func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.APIKey) client := &http.Client{} resp, err := client.Do(req.WithContext(ctx)) if resp.StatusCode != 200 { bytes, _ := io.ReadAll(resp.Body) return resp, fmt.Errorf("%v", string(bytes)) } return resp, err } func (c *OpenAIClient) CreateChatCompletion( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } req := createChatCompletionRequest(params, messages) jsonData, err := json.Marshal(req) if err != nil { return "", err } httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { return "", err } resp, err := c.sendRequest(ctx, httpReq) if err != nil { return "", err } defer resp.Body.Close() var completionResp ChatCompletionResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { return "", err } choice := completionResp.Choices[0] var content string lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content = lastMessage.Content + choice.Message.Content } else { content = choice.Message.Content } toolCalls := choice.Message.ToolCalls if len(toolCalls) > 0 { messages, err := handleToolCalls(params, content, toolCalls, callback, messages) if err != nil { return content, err } return c.CreateChatCompletion(ctx, params, messages, callback) } if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant, Content: content, }) } // Return the user-facing message. return content, nil } func (c *OpenAIClient) CreateChatCompletionStream( ctx context.Context, params model.RequestParameters, messages []model.Message, callback api.ReplyCallback, output chan<- api.Chunk, ) (string, error) { if len(messages) == 0 { return "", fmt.Errorf("Can't create completion from no messages") } req := createChatCompletionRequest(params, messages) req.Stream = true jsonData, err := json.Marshal(req) if err != nil { return "", err } httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { return "", err } resp, err := c.sendRequest(ctx, httpReq) if err != nil { return "", err } defer resp.Body.Close() content := strings.Builder{} toolCalls := []ToolCall{} lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content.WriteString(lastMessage.Content) } reader := bufio.NewReader(resp.Body) for { line, err := reader.ReadBytes('\n') if err != nil { if err == io.EOF { break } return "", err } line = bytes.TrimSpace(line) if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) { continue } line = bytes.TrimPrefix(line, []byte("data: ")) if bytes.Equal(line, []byte("[DONE]")) { break } var streamResp ChatCompletionStreamResponse err = json.Unmarshal(line, &streamResp) if err != nil { return "", err } delta := streamResp.Choices[0].Delta if len(delta.ToolCalls) > 0 { // Construct streamed tool_call arguments for _, tc := range delta.ToolCalls { if tc.Index == nil { return "", fmt.Errorf("Unexpected nil index for streamed tool call.") } if len(toolCalls) <= *tc.Index { toolCalls = append(toolCalls, tc) } else { toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments } } } if len(delta.Content) > 0 { output <- api.Chunk { Content: delta.Content, } content.WriteString(delta.Content) } } if len(toolCalls) > 0 { messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) if err != nil { return content.String(), err } // Recurse into CreateChatCompletionStream with the tool call replies return c.CreateChatCompletionStream(ctx, params, messages, callback, output) } else { if callback != nil { callback(model.Message{ Role: model.MessageRoleAssistant, Content: content.String(), }) } } return content.String(), nil }