package openai

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/api"
	"git.mlow.ca/mlow/lmcli/pkg/provider"
)

// OpenAIClient sends chat completion requests to an OpenAI-compatible
// "/v1/chat/completions" endpoint.
type OpenAIClient struct {
	// APIKey is sent as a Bearer token in the Authorization header.
	APIKey string
	// BaseURL is the server root; "/v1/chat/completions" is appended to it.
	BaseURL string
	// Headers are set on every request after the default headers, so they
	// may override Content-Type or Authorization.
	Headers map[string]string
}

// ChatCompletionMessage is one message in the request's "messages" array, and
// also the shape of a response message or streamed delta.
type ChatCompletionMessage struct {
	Role       string     `json:"role"`
	Content    string     `json:"content,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

// ToolCall is a function call requested by the model.
type ToolCall struct {
	Type string `json:"type"`
	ID   string `json:"id"`
	// Index is a pointer so a missing index in a streamed delta can be
	// distinguished from index 0.
	Index    *int               `json:"index,omitempty"`
	Function FunctionDefinition `json:"function"`
}

// FunctionDefinition is used both to declare a tool (Name, Description,
// Parameters) and to carry a call's JSON-encoded Arguments.
type FunctionDefinition struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
	Arguments   string         `json:"arguments,omitempty"`
}

// ToolParameters is the JSON-schema-like "parameters" object of a tool.
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}

// ToolParameter describes a single named tool parameter.
type ToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Enum        []string `json:"enum,omitempty"`
}

// Tool is an entry in the request's "tools" array.
type Tool struct {
	Type     string             `json:"type"`
	Function FunctionDefinition `json:"function"`
}

// ChatCompletionRequest is the JSON body POSTed to the completions endpoint.
type ChatCompletionRequest struct {
	Model       string                  `json:"model"`
	MaxTokens   int                     `json:"max_tokens,omitempty"`
	Temperature float32                 `json:"temperature,omitempty"`
	Messages    []ChatCompletionMessage `json:"messages"`
	N           int                     `json:"n"`
	Tools       []Tool                  `json:"tools,omitempty"`
	ToolChoice  string                  `json:"tool_choice,omitempty"`
	Stream      bool                    `json:"stream,omitempty"`
}

// ChatCompletionChoice is one choice of a non-streamed response.
type ChatCompletionChoice struct {
	Message ChatCompletionMessage `json:"message"`
}

// ChatCompletionResponse is the decoded body of a non-streamed response.
type ChatCompletionResponse struct {
	Choices []ChatCompletionChoice `json:"choices"`
}

// ChatCompletionStreamChoice is one choice of a streamed chunk.
type ChatCompletionStreamChoice struct {
	Delta ChatCompletionMessage `json:"delta"`
}

// ChatCompletionStreamResponse is the decoded payload of one streamed chunk.
type ChatCompletionStreamResponse struct {
	Choices []ChatCompletionStreamChoice `json:"choices"`
}

// convertTools maps api.ToolSpec definitions onto OpenAI "function" tools,
// collecting each required parameter's name into the schema's Required list.
func convertTools(tools []api.ToolSpec) []Tool {
	openaiTools := make([]Tool, len(tools))
	for i, tool := range tools {
		openaiTools[i].Type = "function"

		params := make(map[string]ToolParameter, len(tool.Parameters))
		var required []string
		for _, param := range tool.Parameters {
			params[param.Name] = ToolParameter{
				Type:        param.Type,
				Description: param.Description,
				Enum:        param.Enum,
			}
			if param.Required {
				required = append(required, param.Name)
			}
		}

		openaiTools[i].Function = FunctionDefinition{
			Name:        tool.Name,
			Description: tool.Description,
			Parameters: ToolParameters{
				Type:       "object",
				Properties: params,
				Required:   required,
			},
		}
	}
	return openaiTools
}

// convertToolCallToOpenAI converts api.ToolCalls into OpenAI tool calls,
// JSON-encoding each call's Parameters into Function.Arguments.
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
	converted := make([]ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].Type = "function"
		converted[i].ID = call.ID
		converted[i].Function.Name = call.Name
		// NOTE(review): a marshal error is ignored and leaves Arguments empty;
		// this signature has no way to report it.
		args, _ := json.Marshal(call.Parameters)
		converted[i].Function.Arguments = string(args)
	}
	return converted
}

// convertToolCallToAPI converts OpenAI tool calls into api.ToolCalls, decoding
// each call's JSON Function.Arguments into Parameters.
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
	converted := make([]api.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].ID = call.ID
		converted[i].Name = call.Function.Name
		// NOTE(review): malformed argument JSON is ignored here, leaving
		// Parameters at whatever Unmarshal managed to fill in.
		_ = json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
	}
	return converted
}

// createChatCompletionRequest builds the request body from params and the
// conversation. "tool_call" messages become assistant messages carrying
// tool_calls; each result of a "tool_result" message becomes its own "tool"
// message; every other role is passed through with its content.
func createChatCompletionRequest(
	params provider.RequestParameters,
	messages []api.Message,
) ChatCompletionRequest {
	requestMessages := make([]ChatCompletionMessage, 0, len(messages))

	for _, m := range messages {
		switch m.Role {
		case "tool_call":
			requestMessages = append(requestMessages, ChatCompletionMessage{
				Role:      "assistant",
				Content:   m.Content,
				ToolCalls: convertToolCallToOpenAI(m.ToolCalls),
			})
		case "tool_result":
			// OpenAI expects one "tool" message per tool call result.
			for _, result := range m.ToolResults {
				requestMessages = append(requestMessages, ChatCompletionMessage{
					Role:       "tool",
					Content:    result.Result,
					ToolCallID: result.ToolCallID,
				})
			}
		default:
			requestMessages = append(requestMessages, ChatCompletionMessage{
				Role:    string(m.Role),
				Content: m.Content,
			})
		}
	}

	request := ChatCompletionRequest{
		Model:       params.Model,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Messages:    requestMessages,
		N:           1, // limit responses to 1 "choice". we use choices[0] to reference it
	}

	if len(params.Toolbox) > 0 {
		request.Tools = convertTools(params.Toolbox)
		request.ToolChoice = "auto"
	}

	return request
}

// sendRequest POSTs r as JSON to BaseURL+"/v1/chat/completions".
//
// On success the caller owns resp and must close resp.Body. On any error,
// including a non-200 status, the returned response is nil and the body has
// already been closed; the error carries the HTTP status and the response body.
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
	jsonData, err := json.Marshal(r)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.BaseURL+"/v1/chat/completions", bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+c.APIKey)
	for header, val := range c.Headers {
		req.Header.Set(header, val)
	}

	// Cancellation and deadlines come from ctx, which is bound to req above.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		// Callers discard the response on error, so close the body here to
		// avoid leaking the connection.
		defer resp.Body.Close()
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("%s: %s", resp.Status, body)
	}

	return resp, nil
}

// CreateChatCompletion requests a single, non-streamed completion.
//
// If the last message is an assistant message (per Role.IsAssistant), its
// content is treated as a prefix and the model's reply is appended to it.
// Returns a tool-call message when the model requested tool calls, otherwise
// an assistant message.
func (c *OpenAIClient) CreateChatCompletion(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	req := createChatCompletionRequest(params, messages)
	resp, err := c.sendRequest(ctx, req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var completionResp ChatCompletionResponse
	if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
		return nil, err
	}
	if len(completionResp.Choices) == 0 {
		return nil, fmt.Errorf("no choices in completion response")
	}

	choice := completionResp.Choices[0]

	content := choice.Message.Content
	if lastMessage := messages[len(messages)-1]; lastMessage.Role.IsAssistant() {
		content = lastMessage.Content + content
	}

	if toolCalls := choice.Message.ToolCalls; len(toolCalls) > 0 {
		return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
	}

	return api.NewMessageWithAssistant(content), nil
}

// sseData returns the payload of a server-sent-events "data:" line, with
// surrounding whitespace removed. It reports false for blank lines, other SSE
// fields, and data lines with an empty payload.
func sseData(line []byte) ([]byte, bool) {
	line = bytes.TrimSpace(line)
	if !bytes.HasPrefix(line, []byte("data:")) {
		return nil, false
	}
	payload := bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data:")))
	if len(payload) == 0 {
		return nil, false
	}
	return payload, true
}

// CreateChatCompletionStream requests a streamed completion. Each content
// delta is sent on output as it arrives, and the assembled message is
// returned once the stream ends.
//
// If the last message is an assistant message (per Role.IsAssistant), its
// content prefixes the returned message's content; it is not sent on output.
// Streamed tool-call fragments are merged by their index. Returns ctx.Err()
// if ctx is cancelled while waiting to send on output.
func (c *OpenAIClient) CreateChatCompletionStream(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
	output chan<- provider.Chunk,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	req := createChatCompletionRequest(params, messages)
	req.Stream = true

	resp, err := c.sendRequest(ctx, req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var content strings.Builder
	var toolCalls []ToolCall
	// toolCallPos maps a streamed tool call's index to its slot in toolCalls,
	// so fragments are merged correctly even if indices are sparse.
	toolCallPos := make(map[int]int)

	if lastMessage := messages[len(messages)-1]; lastMessage.Role.IsAssistant() {
		content.WriteString(lastMessage.Content)
	}

	reader := bufio.NewReader(resp.Body)
	for atEOF := false; !atEOF; {
		line, err := reader.ReadBytes('\n')
		if err == io.EOF {
			// ReadBytes returns any unterminated final line along with io.EOF;
			// process it before the loop ends.
			atEOF = true
		} else if err != nil {
			return nil, err
		}

		data, ok := sseData(line)
		if !ok {
			continue
		}
		if bytes.Equal(data, []byte("[DONE]")) {
			break
		}

		var streamResp ChatCompletionStreamResponse
		if err := json.Unmarshal(data, &streamResp); err != nil {
			return nil, err
		}
		// Chunks without choices carry nothing to accumulate.
		if len(streamResp.Choices) == 0 {
			continue
		}
		delta := streamResp.Choices[0].Delta

		// The first fragment of a tool call carries its ID and name; later
		// fragments for the same index only extend the argument JSON.
		for _, tc := range delta.ToolCalls {
			if tc.Index == nil {
				return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
			}
			if pos, seen := toolCallPos[*tc.Index]; seen {
				toolCalls[pos].Function.Arguments += tc.Function.Arguments
			} else {
				toolCallPos[*tc.Index] = len(toolCalls)
				toolCalls = append(toolCalls, tc)
			}
		}

		if len(delta.Content) > 0 {
			// NOTE(review): each content delta is reported as one token.
			select {
			case output <- provider.Chunk{Content: delta.Content, TokenCount: 1}:
			case <-ctx.Done():
				return nil, ctx.Err()
			}
			content.WriteString(delta.Content)
		}
	}

	if len(toolCalls) > 0 {
		return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
	}

	return api.NewMessageWithAssistant(content.String()), nil
}