Remove go-openai

This commit is contained in:
Matt Low 2024-04-29 06:14:21 +00:00
parent 08a2027332
commit ffe9d299ef
6 changed files with 161 additions and 59 deletions

1
go.mod
View File

@ -9,7 +9,6 @@ require (
github.com/charmbracelet/lipgloss v0.10.0 github.com/charmbracelet/lipgloss v0.10.0
github.com/go-yaml/yaml v2.1.0+incompatible github.com/go-yaml/yaml v2.1.0+incompatible
github.com/muesli/reflow v0.3.0 github.com/muesli/reflow v0.3.0
github.com/sashabaranov/go-openai v1.17.7
github.com/spf13/cobra v1.8.0 github.com/spf13/cobra v1.8.0
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
gopkg.in/yaml.v2 v2.2.2 gopkg.in/yaml.v2 v2.2.2

2
go.sum
View File

@ -61,8 +61,6 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=

View File

@ -46,7 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
err = nil err = nil
} }
} }
return response, nil return response, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the

View File

@ -75,6 +75,7 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
for _, m := range *c.Config.OpenAI.Models { for _, m := range *c.Config.OpenAI.Models {
if m == model { if m == model {
openai := &openai.OpenAIClient{ openai := &openai.OpenAIClient{
BaseURL: "https://api.openai.com/v1",
APIKey: *c.Config.OpenAI.APIKey, APIKey: *c.Config.OpenAI.APIKey,
} }
return openai, nil return openai, nil

View File

@ -1,45 +1,30 @@
package openai package openai
import ( import (
"bufio"
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai"
) )
type OpenAIClient struct { func convertTools(tools []model.Tool) []Tool {
APIKey string openaiTools := make([]Tool, len(tools))
}
type OpenAIToolParameters struct {
Type string `json:"type"`
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type OpenAIToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
func convertTools(tools []model.Tool) []openai.Tool {
openaiTools := make([]openai.Tool, len(tools))
for i, tool := range tools { for i, tool := range tools {
openaiTools[i].Type = "function" openaiTools[i].Type = "function"
params := make(map[string]OpenAIToolParameter) params := make(map[string]ToolParameter)
var required []string var required []string
for _, param := range tool.Parameters { for _, param := range tool.Parameters {
params[param.Name] = OpenAIToolParameter{ params[param.Name] = ToolParameter{
Type: param.Type, Type: param.Type,
Description: param.Description, Description: param.Description,
Enum: param.Enum, Enum: param.Enum,
@ -49,10 +34,10 @@ func convertTools(tools []model.Tool) []openai.Tool {
} }
} }
openaiTools[i].Function = openai.FunctionDefinition{ openaiTools[i].Function = FunctionDefinition{
Name: tool.Name, Name: tool.Name,
Description: tool.Description, Description: tool.Description,
Parameters: OpenAIToolParameters{ Parameters: ToolParameters{
Type: "object", Type: "object",
Properties: params, Properties: params,
Required: required, Required: required,
@ -62,8 +47,8 @@ func convertTools(tools []model.Tool) []openai.Tool {
return openaiTools return openaiTools
} }
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall { func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
converted := make([]openai.ToolCall, len(toolCalls)) converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].Type = "function" converted[i].Type = "function"
converted[i].ID = call.ID converted[i].ID = call.ID
@ -75,7 +60,7 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
return converted return converted
} }
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall { func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
converted := make([]model.ToolCall, len(toolCalls)) converted := make([]model.ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].ID = call.ID converted[i].ID = call.ID
@ -86,16 +71,15 @@ func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
} }
func createChatCompletionRequest( func createChatCompletionRequest(
c *OpenAIClient,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
) openai.ChatCompletionRequest { ) ChatCompletionRequest {
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages)) requestMessages := make([]ChatCompletionMessage, 0, len(messages))
for _, m := range messages { for _, m := range messages {
switch m.Role { switch m.Role {
case "tool_call": case "tool_call":
message := openai.ChatCompletionMessage{} message := ChatCompletionMessage{}
message.Role = "assistant" message.Role = "assistant"
message.Content = m.Content message.Content = m.Content
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls) message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
@ -103,21 +87,21 @@ func createChatCompletionRequest(
case "tool_result": case "tool_result":
// expand tool_result messages' results into multiple openAI messages // expand tool_result messages' results into multiple openAI messages
for _, result := range m.ToolResults { for _, result := range m.ToolResults {
message := openai.ChatCompletionMessage{} message := ChatCompletionMessage{}
message.Role = "tool" message.Role = "tool"
message.Content = result.Result message.Content = result.Result
message.ToolCallID = result.ToolCallID message.ToolCallID = result.ToolCallID
requestMessages = append(requestMessages, message) requestMessages = append(requestMessages, message)
} }
default: default:
message := openai.ChatCompletionMessage{} message := ChatCompletionMessage{}
message.Role = string(m.Role) message.Role = string(m.Role)
message.Content = m.Content message.Content = m.Content
requestMessages = append(requestMessages, message) requestMessages = append(requestMessages, message)
} }
} }
request := openai.ChatCompletionRequest{ request := ChatCompletionRequest{
Model: params.Model, Model: params.Model,
MaxTokens: params.MaxTokens, MaxTokens: params.MaxTokens,
Temperature: params.Temperature, Temperature: params.Temperature,
@ -136,7 +120,7 @@ func createChatCompletionRequest(
func handleToolCalls( func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []openai.ToolCall, toolCalls []ToolCall,
callback provider.ReplyCallback, callback provider.ReplyCallback,
messages []model.Message, messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
@ -177,6 +161,14 @@ func handleToolCalls(
return messages, nil return messages, nil
} }
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{}
return client.Do(req.WithContext(ctx))
}
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
@ -187,14 +179,30 @@ func (c *OpenAIClient) CreateChatCompletion(
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
client := openai.NewClient(c.APIKey) req := createChatCompletionRequest(params, messages)
req := createChatCompletionRequest(c, params, messages) jsonData, err := json.Marshal(req)
resp, err := client.CreateChatCompletion(ctx, req)
if err != nil { if err != nil {
return "", err return "", err
} }
choice := resp.Choices[0] httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return "", err
}
choice := completionResp.Choices[0]
var content string var content string
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
@ -236,36 +244,60 @@ func (c *OpenAIClient) CreateChatCompletionStream(
return "", fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
client := openai.NewClient(c.APIKey) req := createChatCompletionRequest(params, messages)
req := createChatCompletionRequest(c, params, messages) req.Stream = true
stream, err := client.CreateChatCompletionStream(ctx, req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return "", err
} }
defer stream.Close()
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return "", err
}
resp, err := c.sendRequest(ctx, httpReq)
if err != nil {
return "", err
}
defer resp.Body.Close()
content := strings.Builder{} content := strings.Builder{}
toolCalls := []openai.ToolCall{} toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() { if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content) content.WriteString(lastMessage.Content)
} }
// Iterate stream segments reader := bufio.NewReader(resp.Body)
for { for {
response, e := stream.Recv() line, err := reader.ReadBytes('\n')
if errors.Is(e, io.EOF) { if err != nil {
if err == io.EOF {
break
}
return "", err
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
line = bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(line, []byte("[DONE]")) {
break break
} }
if e != nil { var streamResp ChatCompletionStreamResponse
err = e err = json.Unmarshal(line, &streamResp)
break if err != nil {
return "", err
} }
delta := response.Choices[0].Delta delta := streamResp.Choices[0].Delta
if len(delta.ToolCalls) > 0 { if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments // Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls { for _, tc := range delta.ToolCalls {
@ -278,7 +310,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
} }
} }
} else { }
if len(delta.Content) > 0 {
output <- delta.Content output <- delta.Content
content.WriteString(delta.Content) content.WriteString(delta.Content)
} }
@ -301,5 +334,5 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
} }
return content.String(), err return content.String(), nil
} }

View File

@ -0,0 +1,71 @@
package openai
type OpenAIClient struct {
APIKey string
BaseURL string
}
type ChatCompletionMessage struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}
type ToolCall struct {
Type string `json:"type"`
ID string `json:"id"`
Index *int `json:"index,omitempty"`
Function FunctionDefinition `json:"function"`
}
type FunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolParameters `json:"parameters"`
Arguments string `json:"arguments,omitempty"`
}
type ToolParameters struct {
Type string `json:"type"`
Properties map[string]ToolParameter `json:"properties,omitempty"`
Required []string `json:"required,omitempty"`
}
type ToolParameter struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}
type Tool struct {
Type string `json:"type"`
Function FunctionDefinition `json:"function"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float32 `json:"temperature,omitempty"`
Messages []ChatCompletionMessage `json:"messages"`
N int `json:"n"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice string `json:"tool_choice,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ChatCompletionChoice struct {
Message ChatCompletionMessage `json:"message"`
}
type ChatCompletionResponse struct {
Choices []ChatCompletionChoice `json:"choices"`
}
type ChatCompletionStreamChoice struct {
Delta ChatCompletionMessage `json:"delta"`
}
type ChatCompletionStreamResponse struct {
Choices []ChatCompletionStreamChoice `json:"choices"`
}