diff --git a/go.mod b/go.mod index bca6e75..eb7f5dc 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,6 @@ require ( github.com/charmbracelet/lipgloss v0.10.0 github.com/go-yaml/yaml v2.1.0+incompatible github.com/muesli/reflow v0.3.0 - github.com/sashabaranov/go-openai v1.17.7 github.com/spf13/cobra v1.8.0 github.com/sqids/sqids-go v0.4.1 gopkg.in/yaml.v2 v2.2.2 diff --git a/go.sum b/go.sum index 0cbc36b..1e39f2d 100644 --- a/go.sum +++ b/go.sum @@ -61,8 +61,6 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM= -github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 62b71cb..5bd7205 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -46,7 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba err = nil } } - return response, nil + return response, err } // lookupConversation either returns the conversation found by the diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 467e452..97ea288 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -19,7 +19,7 @@ type Context struct { Config *Config Store ConversationStore - Chroma *tty.ChromaHighlighter + Chroma *tty.ChromaHighlighter EnabledTools []model.Tool } @@ -75,7 +75,8 @@ func (c *Context) GetCompletionProvider(model string) 
(provider.ChatCompletionCl for _, m := range *c.Config.OpenAI.Models { if m == model { openai := &openai.OpenAIClient{ - APIKey: *c.Config.OpenAI.APIKey, + BaseURL: "https://api.openai.com/v1", + APIKey: *c.Config.OpenAI.APIKey, } return openai, nil } diff --git a/pkg/lmcli/provider/openai/openai.go b/pkg/lmcli/provider/openai/openai.go index 0b8224b..6dd053f 100644 --- a/pkg/lmcli/provider/openai/openai.go +++ b/pkg/lmcli/provider/openai/openai.go @@ -1,45 +1,30 @@ package openai import ( + "bufio" + "bytes" "context" "encoding/json" - "errors" "fmt" "io" + "net/http" "strings" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" - openai "github.com/sashabaranov/go-openai" ) -type OpenAIClient struct { - APIKey string -} - -type OpenAIToolParameters struct { - Type string `json:"type"` - Properties map[string]OpenAIToolParameter `json:"properties,omitempty"` - Required []string `json:"required,omitempty"` -} - -type OpenAIToolParameter struct { - Type string `json:"type"` - Description string `json:"description"` - Enum []string `json:"enum,omitempty"` -} - -func convertTools(tools []model.Tool) []openai.Tool { - openaiTools := make([]openai.Tool, len(tools)) +func convertTools(tools []model.Tool) []Tool { + openaiTools := make([]Tool, len(tools)) for i, tool := range tools { openaiTools[i].Type = "function" - params := make(map[string]OpenAIToolParameter) + params := make(map[string]ToolParameter) var required []string for _, param := range tool.Parameters { - params[param.Name] = OpenAIToolParameter{ + params[param.Name] = ToolParameter{ Type: param.Type, Description: param.Description, Enum: param.Enum, @@ -49,10 +34,10 @@ func convertTools(tools []model.Tool) []openai.Tool { } } - openaiTools[i].Function = openai.FunctionDefinition{ + openaiTools[i].Function = FunctionDefinition{ Name: tool.Name, Description: tool.Description, - Parameters: OpenAIToolParameters{ + Parameters: 
ToolParameters{ Type: "object", Properties: params, Required: required, @@ -62,8 +47,8 @@ func convertTools(tools []model.Tool) []openai.Tool { return openaiTools } -func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall { - converted := make([]openai.ToolCall, len(toolCalls)) +func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall { + converted := make([]ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].Type = "function" converted[i].ID = call.ID @@ -75,7 +60,7 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall { return converted } -func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall { +func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall { converted := make([]model.ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].ID = call.ID @@ -86,16 +71,15 @@ func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall { } func createChatCompletionRequest( - c *OpenAIClient, params model.RequestParameters, messages []model.Message, -) openai.ChatCompletionRequest { - requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages)) +) ChatCompletionRequest { + requestMessages := make([]ChatCompletionMessage, 0, len(messages)) for _, m := range messages { switch m.Role { case "tool_call": - message := openai.ChatCompletionMessage{} + message := ChatCompletionMessage{} message.Role = "assistant" message.Content = m.Content message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) @@ -103,21 +87,21 @@ func createChatCompletionRequest( case "tool_result": // expand tool_result messages' results into multiple openAI messages for _, result := range m.ToolResults { - message := openai.ChatCompletionMessage{} + message := ChatCompletionMessage{} message.Role = "tool" message.Content = result.Result message.ToolCallID = result.ToolCallID requestMessages = append(requestMessages, message) } default: - message := 
openai.ChatCompletionMessage{} + message := ChatCompletionMessage{} message.Role = string(m.Role) message.Content = m.Content requestMessages = append(requestMessages, message) } } - request := openai.ChatCompletionRequest{ + request := ChatCompletionRequest{ Model: params.Model, MaxTokens: params.MaxTokens, Temperature: params.Temperature, @@ -136,7 +120,7 @@ func createChatCompletionRequest( func handleToolCalls( params model.RequestParameters, content string, - toolCalls []openai.ToolCall, + toolCalls []ToolCall, callback provider.ReplyCallback, messages []model.Message, ) ([]model.Message, error) { @@ -177,6 +161,14 @@ func handleToolCalls( return messages, nil } +func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + client := &http.Client{} + return client.Do(req.WithContext(ctx)) +} + func (c *OpenAIClient) CreateChatCompletion( ctx context.Context, params model.RequestParameters, @@ -187,14 +179,30 @@ func (c *OpenAIClient) CreateChatCompletion( return "", fmt.Errorf("Can't create completion from no messages") } - client := openai.NewClient(c.APIKey) - req := createChatCompletionRequest(c, params, messages) - resp, err := client.CreateChatCompletion(ctx, req) + req := createChatCompletionRequest(params, messages) + jsonData, err := json.Marshal(req) if err != nil { return "", err } - choice := resp.Choices[0] + httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var completionResp ChatCompletionResponse + err = json.NewDecoder(resp.Body).Decode(&completionResp) + if err != nil { + return "", err + } + + choice := completionResp.Choices[0] var content string lastMessage := messages[len(messages)-1] @@ -236,36 
+244,60 @@ func (c *OpenAIClient) CreateChatCompletionStream( return "", fmt.Errorf("Can't create completion from no messages") } - client := openai.NewClient(c.APIKey) - req := createChatCompletionRequest(c, params, messages) + req := createChatCompletionRequest(params, messages) + req.Stream = true - stream, err := client.CreateChatCompletionStream(ctx, req) + jsonData, err := json.Marshal(req) if err != nil { return "", err } - defer stream.Close() + + httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + if err != nil { + return "", err + } + + resp, err := c.sendRequest(ctx, httpReq) + if err != nil { + return "", err + } + defer resp.Body.Close() content := strings.Builder{} - toolCalls := []openai.ToolCall{} + toolCalls := []ToolCall{} lastMessage := messages[len(messages)-1] if lastMessage.Role.IsAssistant() { content.WriteString(lastMessage.Content) } - // Iterate stream segments + reader := bufio.NewReader(resp.Body) for { - response, e := stream.Recv() - if errors.Is(e, io.EOF) { + line, err := reader.ReadBytes('\n') + if err != nil { + if err == io.EOF { + break + } + return "", err + } + + line = bytes.TrimSpace(line) + if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) { + continue + } + + line = bytes.TrimPrefix(line, []byte("data: ")) + if bytes.Equal(line, []byte("[DONE]")) { break } - if e != nil { - err = e - break + var streamResp ChatCompletionStreamResponse + err = json.Unmarshal(line, &streamResp) + if err != nil { + return "", err } - delta := response.Choices[0].Delta + delta := streamResp.Choices[0].Delta if len(delta.ToolCalls) > 0 { // Construct streamed tool_call arguments for _, tc := range delta.ToolCalls { @@ -278,7 +310,8 @@ func (c *OpenAIClient) CreateChatCompletionStream( toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments } } - } else { + } + if len(delta.Content) > 0 { output <- delta.Content content.WriteString(delta.Content) } @@ -301,5 +334,5 @@ func (c 
*OpenAIClient) CreateChatCompletionStream( } } - return content.String(), err + return content.String(), nil } diff --git a/pkg/lmcli/provider/openai/types.go b/pkg/lmcli/provider/openai/types.go new file mode 100644 index 0000000..27eee03 --- /dev/null +++ b/pkg/lmcli/provider/openai/types.go @@ -0,0 +1,71 @@ +package openai + +type OpenAIClient struct { + APIKey string + BaseURL string +} + +type ChatCompletionMessage struct { + Role string `json:"role"` + Content string `json:"content,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` +} + +type ToolCall struct { + Type string `json:"type"` + ID string `json:"id"` + Index *int `json:"index,omitempty"` + Function FunctionDefinition `json:"function"` +} + +type FunctionDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters ToolParameters `json:"parameters"` + Arguments string `json:"arguments,omitempty"` +} + +type ToolParameters struct { + Type string `json:"type"` + Properties map[string]ToolParameter `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type ToolParameter struct { + Type string `json:"type"` + Description string `json:"description"` + Enum []string `json:"enum,omitempty"` +} + +type Tool struct { + Type string `json:"type"` + Function FunctionDefinition `json:"function"` +} + +type ChatCompletionRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + Messages []ChatCompletionMessage `json:"messages"` + N int `json:"n,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type ChatCompletionChoice struct { + Message ChatCompletionMessage `json:"message"` +} + +type ChatCompletionResponse struct { + Choices []ChatCompletionChoice `json:"choices"` +} + +type 
ChatCompletionStreamChoice struct { + Delta ChatCompletionMessage `json:"delta"` +} + +type ChatCompletionStreamResponse struct { + Choices []ChatCompletionStreamChoice `json:"choices"` +}