Private
Public Access
1
0
Files
lmcli/pkg/lmcli/provider/openai/openai.go

273 lines
6.7 KiB
Go

package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai"
)
type (
	// OpenAIClient is the OpenAI chat-completions provider. APIKey is handed
	// to openai.NewClient each time a request is made.
	OpenAIClient struct {
		APIKey string
	}

	// OpenAIToolParameters is the schema object placed in a tool's function
	// definition: the schema type, the named properties, and the names of
	// the properties that are mandatory. Empty Properties/Required are
	// omitted from the JSON encoding.
	OpenAIToolParameters struct {
		Type       string                         `json:"type"`
		Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
		Required   []string                       `json:"required,omitempty"`
	}

	// OpenAIToolParameter is one property inside OpenAIToolParameters.
	// Description is always encoded (even when empty); Enum, which lists the
	// permitted values, is omitted when empty.
	OpenAIToolParameter struct {
		Type        string   `json:"type"`
		Description string   `json:"description"`
		Enum        []string `json:"enum,omitempty"`
	}
)
// convertTools maps lmcli tool definitions onto go-openai function tools.
//
// Every tool becomes a "function" tool whose parameters are described by an
// "object" schema. Each model parameter becomes a schema property keyed by its
// name. Parameters marked Required are also listed, in declaration order, in
// the schema's required list. The result always has len(tools) entries.
func convertTools(tools []model.Tool) []openai.Tool {
	result := make([]openai.Tool, 0, len(tools))
	for _, t := range tools {
		properties := make(map[string]OpenAIToolParameter, len(t.Parameters))
		var mandatory []string
		for _, p := range t.Parameters {
			properties[p.Name] = OpenAIToolParameter{
				Type:        p.Type,
				Description: p.Description,
				Enum:        p.Enum,
			}
			if p.Required {
				mandatory = append(mandatory, p.Name)
			}
		}

		result = append(result, openai.Tool{
			Type: "function",
			Function: openai.FunctionDefinition{
				Name:        t.Name,
				Description: t.Description,
				Parameters: OpenAIToolParameters{
					Type:       "object",
					Properties: properties,
					Required:   mandatory,
				},
			},
		})
	}
	return result
}
// convertToolCallToOpenAI converts lmcli tool calls into go-openai ToolCall
// values so they can be sent back to the API inside an assistant message.
//
// Each call keeps its ID and function name. Its Parameters are re-encoded as
// the JSON string carried in Function.Arguments.
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
	converted := make([]openai.ToolCall, len(toolCalls))
	for i, call := range toolCalls {
		converted[i].Type = "function"
		converted[i].ID = call.ID
		converted[i].Function.Name = call.Name

		// The signature has no error return. If the parameters cannot be
		// encoded, send an empty JSON object rather than an empty (invalid)
		// arguments string. Naming the result "args" avoids shadowing the
		// encoding/json package.
		args, err := json.Marshal(call.Parameters)
		if err != nil {
			args = []byte("{}")
		}
		converted[i].Function.Arguments = string(args)
	}
	return converted
}
// convertToolCallToAPI converts tool calls returned by the OpenAI API into
// lmcli's model.ToolCall form. Each call's JSON argument string is decoded
// into Parameters.
//
// Decoding is best-effort. The signature cannot report errors, so arguments
// that fail to decode leave Parameters empty or only partially populated.
// NOTE(review): malformed arguments from the model therefore reach the tools
// silently; consider surfacing the decode error.
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
	out := make([]model.ToolCall, 0, len(toolCalls))
	for _, tc := range toolCalls {
		var mc model.ToolCall
		mc.ID = tc.ID
		mc.Name = tc.Function.Name
		_ = json.Unmarshal([]byte(tc.Function.Arguments), &mc.Parameters)
		out = append(out, mc)
	}
	return out
}
// createChatCompletionRequest builds the go-openai request for a conversation.
//
// Messages are translated as follows:
//   - role "tool_result": fanned out into one "tool" message per result,
//     each linked to its originating call through ToolCallID;
//   - role "tool_call": sent as an "assistant" message carrying the
//     converted ToolCalls alongside its content;
//   - any other role: forwarded as-is with its content.
//
// Exactly one choice is requested (N: 1) because callers read Choices[0].
// When params.ToolBag is non-empty, its tools are offered with ToolChoice
// "auto". The client argument is currently unused.
func createChatCompletionRequest(
	c *OpenAIClient,
	params model.RequestParameters,
	messages []model.Message,
) openai.ChatCompletionRequest {
	converted := make([]openai.ChatCompletionMessage, 0, len(messages))
	for _, msg := range messages {
		if msg.Role == "tool_result" {
			for _, res := range msg.ToolResults {
				converted = append(converted, openai.ChatCompletionMessage{
					Role:       "tool",
					Content:    res.Result,
					ToolCallID: res.ToolCallID,
				})
			}
			continue
		}

		out := openai.ChatCompletionMessage{Content: msg.Content}
		if msg.Role == "tool_call" {
			out.Role = "assistant"
			out.ToolCalls = convertToolCallToOpenAI(msg.ToolCalls)
		} else {
			out.Role = string(msg.Role)
		}
		converted = append(converted, out)
	}

	req := openai.ChatCompletionRequest{
		Model:       params.Model,
		MaxTokens:   params.MaxTokens,
		Temperature: params.Temperature,
		Messages:    converted,
		N:           1,
	}
	if len(params.ToolBag) > 0 {
		req.Tools = convertTools(params.ToolBag)
		req.ToolChoice = "auto"
	}
	return req
}
// handleToolCalls executes the tools the model asked for and returns the
// messages that record that exchange.
//
// On success it returns exactly two messages, in order:
//  1. a MessageRoleToolCall message holding the assistant's content and the
//     converted calls;
//  2. a MessageRoleToolResult message holding the output of
//     tools.ExecuteToolCalls run against params.ToolBag.
//
// If tool execution fails, no messages are returned and the error is passed
// through unchanged.
func handleToolCalls(
	params model.RequestParameters,
	content string,
	toolCalls []openai.ToolCall,
) ([]model.Message, error) {
	calls := convertToolCallToAPI(toolCalls)

	results, err := tools.ExecuteToolCalls(calls, params.ToolBag)
	if err != nil {
		return nil, err
	}

	return []model.Message{
		{Role: model.MessageRoleToolCall, Content: content, ToolCalls: calls},
		{Role: model.MessageRoleToolResult, ToolResults: results},
	}, nil
}
// CreateChatCompletion sends the conversation to the OpenAI chat completions
// endpoint and returns the assistant's final text reply.
//
// When the model answers with tool calls, the tools are executed, the
// resulting tool-call and tool-result messages are appended to the
// conversation, and the request is repeated. This continues until the model
// replies without calling tools.
// NOTE(review): there is no cap on the number of tool-call rounds.
//
// If replies is non-nil, every message produced along the way (tool calls,
// tool results and the final assistant reply) is appended to *replies. The
// caller's messages slice is never written to.
func (c *OpenAIClient) CreateChatCompletion(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)
	resp, err := client.CreateChatCompletion(ctx, req)
	if err != nil {
		return "", err
	}
	// The request asks for a single choice. Guard against an empty response
	// instead of panicking on the index below.
	if len(resp.Choices) == 0 {
		return "", errors.New("chat completion response contained no choices")
	}
	choice := resp.Choices[0]

	if toolCalls := choice.Message.ToolCalls; len(toolCalls) > 0 {
		results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
		if err != nil {
			return "", err
		}
		// replies is optional; the final-reply branch below already treats
		// it that way.
		if replies != nil {
			*replies = append(*replies, results...)
		}
		// Cap the capacity so append copies instead of writing into spare
		// capacity of the caller's backing array.
		messages = append(messages[:len(messages):len(messages)], results...)
		return c.CreateChatCompletion(ctx, params, messages, replies)
	}

	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: choice.Message.Content,
		})
	}
	// Return the user-facing message.
	return choice.Message.Content, nil
}
// CreateChatCompletionStream is the streaming counterpart of
// CreateChatCompletion. Text deltas are sent to output as they arrive and
// are also accumulated into the returned string.
//
// Streamed tool-call fragments are reassembled by their index. If the stream
// ends cleanly with tool calls present, the tools are executed and their
// messages are appended to the conversation (and to *replies when replies is
// non-nil). A new stream is then started. This repeats until the model
// replies without calling tools. The caller's messages slice is never
// written to.
//
// If the stream fails part-way, any tool calls it was assembling are NOT
// executed, because their arguments may be truncated. The partial text is
// recorded as an assistant reply and returned together with the error.
func (c *OpenAIClient) CreateChatCompletionStream(
	ctx context.Context,
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
	output chan<- string,
) (string, error) {
	client := openai.NewClient(c.APIKey)
	req := createChatCompletionRequest(c, params, messages)
	stream, err := client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return "", err
	}
	defer stream.Close()

	content := strings.Builder{}
	var toolCalls []openai.ToolCall

	// Iterate stream segments.
	for {
		response, e := stream.Recv()
		if errors.Is(e, io.EOF) {
			break
		}
		if e != nil {
			err = e
			break
		}
		// Chunks may arrive without any choices (OpenAI's usage-only chunk
		// is one example). Skip them rather than panic on Choices[0].
		if len(response.Choices) == 0 {
			continue
		}
		delta := response.Choices[0].Delta

		if len(delta.ToolCalls) == 0 {
			output <- delta.Content
			content.WriteString(delta.Content)
			continue
		}

		// Reassemble streamed tool_call arguments. The first fragment for an
		// index carries the call's ID and name; later fragments only extend
		// its arguments.
		for _, tc := range delta.ToolCalls {
			if tc.Index == nil {
				return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
			}
			idx := *tc.Index
			switch {
			case idx == len(toolCalls):
				toolCalls = append(toolCalls, tc)
			case idx >= 0 && idx < len(toolCalls):
				toolCalls[idx].Function.Arguments += tc.Function.Arguments
			default:
				// A gap or negative index would otherwise misplace the call
				// and panic on a later fragment.
				return "", fmt.Errorf("unexpected index %d for streamed tool call (have %d calls)", idx, len(toolCalls))
			}
		}
	}

	if len(toolCalls) > 0 && err == nil {
		results, err := handleToolCalls(params, content.String(), toolCalls)
		if err != nil {
			return content.String(), err
		}
		if replies != nil {
			*replies = append(*replies, results...)
		}
		// Recurse with the tool-call messages appended. Cap the capacity so
		// the caller's backing array is not written to.
		messages = append(messages[:len(messages):len(messages)], results...)
		return c.CreateChatCompletionStream(ctx, params, messages, replies, output)
	}

	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: content.String(),
		})
	}
	return content.String(), err
}