package openai

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
	openai "github.com/sashabaranov/go-openai"
)

// OpenAIClient sends chat-completion requests to the OpenAI API using APIKey.
// A fresh go-openai client is constructed for every request.
type OpenAIClient struct {
	APIKey string
}

// OpenAIToolParameters is the JSON-schema "object" placed in a function
// definition's Parameters field to describe a tool's arguments.
type OpenAIToolParameters struct {
	Type       string                         `json:"type"`
	Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
	Required   []string                       `json:"required,omitempty"`
}

// OpenAIToolParameter describes one named argument of a tool.
type OpenAIToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}

// convertTools translates lmcli tool definitions into OpenAI "function"
// tools. Each tool's parameters become the properties of a JSON-schema
// object; parameters marked Required are listed in its "required" array
// (omitted from the JSON when empty).
func convertTools(toolDefs []model.Tool) []openai.Tool {
	openaiTools := make([]openai.Tool, len(toolDefs))
	for i, tool := range toolDefs {
		openaiTools[i].Type = "function"

		params := make(map[string]OpenAIToolParameter, len(tool.Parameters))
		var required []string
		for _, param := range tool.Parameters {
			params[param.Name] = OpenAIToolParameter{
				Type:        param.Type,
				Description: param.Description,
				Enum:        param.Enum,
			}
			if param.Required {
				required = append(required, param.Name)
			}
		}

		openaiTools[i].Function = openai.FunctionDefinition{
			Name:        tool.Name,
			Description: tool.Description,
			Parameters: OpenAIToolParameters{
				Type:       "object",
				Properties: params,
				Required:   required,
			},
		}
	}
	return openaiTools
}

// convertToolCallToOpenAI converts lmcli tool calls into OpenAI tool calls,
// JSON-encoding each call's Parameters into the function Arguments string.
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
	converted := make([]openai.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].Type = "function"
		converted[i].ID = call.ID
		converted[i].Function.Name = call.Name
		// NOTE(review): the Marshal error is ignored; on failure Arguments
		// becomes "". Presumably Parameters is always JSON-encodable (it is
		// typically produced by convertToolCallToAPI's Unmarshal) — confirm.
		args, _ := json.Marshal(call.Parameters)
		converted[i].Function.Arguments = string(args)
	}
	return converted
}

// convertToolCallToAPI converts OpenAI tool calls into lmcli tool calls,
// decoding each call's JSON Arguments string into Parameters.
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
	converted := make([]model.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].ID = call.ID
		converted[i].Name = call.Function.Name
		// NOTE(review): decode errors are ignored, so malformed arguments
		// from the model leave Parameters at its zero value; consider
		// surfacing them to the caller.
		_ = json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
	}
	return converted
}

// createChatCompletionRequest builds an OpenAI request from the lmcli
// conversation. "tool_call" messages become assistant messages carrying
// ToolCalls; each "tool_result" message is expanded into one "tool" message
// per result; every other message is passed through with its role and
// content. Tools from params.ToolBag are attached with ToolChoice "auto".
// The c parameter is currently unused.
func createChatCompletionRequest(
	c *OpenAIClient,
	params model.RequestParameters,
	messages []model.Message,
) openai.ChatCompletionRequest {
	requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, m := range messages {
		switch m.Role {
		case "tool_call":
			requestMessages = append(requestMessages, openai.ChatCompletionMessage{
				Role:      "assistant",
				Content:   m.Content,
				ToolCalls: convertToolCallToOpenAI(m.ToolCalls),
			})
		case "tool_result":
			// Expand a tool_result message's results into multiple OpenAI
			// messages, one per tool call being answered.
			for _, result := range m.ToolResults {
				requestMessages = append(requestMessages, openai.ChatCompletionMessage{
					Role:       "tool",
					Content:    result.Result,
					ToolCallID: result.ToolCallID,
				})
			}
		default:
			requestMessages = append(requestMessages, openai.ChatCompletionMessage{
				Role:    string(m.Role),
				Content: m.Content,
			})
		}
	}

	request := openai.ChatCompletionRequest{
		Model:       params.Model,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Messages:    requestMessages,
		N:           1, // limit responses to 1 "choice"; we use Choices[0] to reference it
	}

	if len(params.ToolBag) > 0 {
		request.Tools = convertTools(params.ToolBag)
		request.ToolChoice = "auto"
	}

	return request
}

// handleToolCalls wraps the model's tool calls (and any accompanying text) in
// a tool_call message, executes them against params.ToolBag via
// tools.ExecuteToolCalls, and returns that message followed by a tool_result
// message holding the results. An execution error is returned as-is.
func handleToolCalls(
	params model.RequestParameters,
	content string,
	toolCalls []openai.ToolCall,
) ([]model.Message, error) {
	toolCall := model.Message{
		Role:      model.MessageRoleToolCall,
		Content:   content,
		ToolCalls: convertToolCallToAPI(toolCalls),
	}

	toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
	if err != nil {
		return nil, err
	}

	toolResult := model.Message{
		Role:        model.MessageRoleToolResult,
		ToolResults: toolResults,
	}

	return []model.Message{toolCall, toolResult}, nil
}

// CreateChatCompletion sends messages to the API and returns the assistant's
// final text.
//
// When the model answers with tool calls, they are executed, the resulting
// tool_call/tool_result messages are appended to *replies (if replies is
// non-nil) and to the conversation, and the request is repeated until the
// model answers without tool calls. The final assistant message is also
// appended to *replies. messages itself is never modified.
func (c *OpenAIClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)
	resp, err := client.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", err
	}
	// Guard the Choices[0] access below; an empty slice would panic.
	if len(resp.Choices) == 0 {
		return "", errors.New("chat completion response contained no choices")
	}

	choice := resp.Choices[0]

	toolCalls := choice.Message.ToolCalls
	if len(toolCalls) > 0 {
		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
		if err != nil {
			return "", err
		}
		if replies != nil {
			*replies = append(*replies, results...)
		}

		// Recurse into CreateChatCompletion with the tool call replies. The
		// full slice expression caps capacity so append copies instead of
		// writing into the caller's backing array.
		messages = append(messages[:len(messages):len(messages)], results...)
		return c.CreateChatCompletion(ctx, params, messages, replies)
	}

	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: choice.Message.Content,
		})
	}

	// Return the user-facing message.
	return choice.Message.Content, nil
}

// CreateChatCompletionStream is the streaming counterpart of
// CreateChatCompletion. Each text fragment is sent on output as it arrives
// (output is not closed here). Streamed tool-call fragments are reassembled by
// index, executed, recorded in *replies (if non-nil), and the conversation is
// continued recursively. The return value is the accumulated text of the final
// (non-tool) turn.
//
// If the stream fails mid-response, the partial text is still recorded as an
// assistant reply and returned along with the error, but partially streamed
// tool calls are not executed because their arguments may be truncated.
func (c *OpenAIClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
	output chan<- string,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)

	stream, err := client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return "", err
	}
	defer stream.Close()

	content := strings.Builder{}
	var toolCalls []openai.ToolCall

	// Iterate stream segments.
recv:
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}
		// Some chunks carry no choices; skip them rather than panicking.
		if len(response.Choices) == 0 {
			continue
		}

		delta := response.Choices[0].Delta
		if len(delta.ToolCalls) > 0 {
			// Reassemble streamed tool_call arguments: the first fragment
			// for an index carries the ID and name, later ones only append
			// to the arguments.
			for _, tc := range delta.ToolCalls {
				if tc.Index == nil {
					return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
				}
				idx := *tc.Index
				switch {
				case idx < 0 || idx > len(toolCalls):
					return "", fmt.Errorf("unexpected index %d for streamed tool call (have %d)", idx, len(toolCalls))
				case idx == len(toolCalls):
					toolCalls = append(toolCalls, tc)
				default:
					toolCalls[idx].Function.Arguments += tc.Function.Arguments
				}
			}
			continue
		}

		// Don't block forever if the consumer stopped reading after
		// cancellation; treat it like any other stream error.
		select {
		case output <- delta.Content:
		case <-ctx.Done():
			err = ctx.Err()
			break recv
		}
		content.WriteString(delta.Content)
	}

	if len(toolCalls) > 0 {
		if err != nil {
			// The tool-call arguments may be incomplete; don't execute them.
			return content.String(), err
		}

		results, err := handleToolCalls(params, content.String(), toolCalls)
		if err != nil {
			return content.String(), err
		}
		if replies != nil {
			*replies = append(*replies, results...)
		}

		// Recurse into CreateChatCompletionStream with the tool call replies,
		// copying rather than writing into the caller's backing array.
		messages = append(messages[:len(messages):len(messages)], results...)
		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
	}

	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}

	return content.String(), err
}