lmcli/pkg/api/provider/openai/openai.go

353 lines
8.6 KiB
Go
Raw Normal View History

package openai
import (
2024-04-29 00:14:21 -06:00
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
2024-04-29 00:14:21 -06:00
"net/http"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
// OpenAIClient talks to an OpenAI-style chat completions HTTP API.
type OpenAIClient struct {
// APIKey is sent as the bearer token in the Authorization header.
APIKey string
// BaseURL is the API root; "/v1/chat/completions" is appended to it when
// requests are sent.
BaseURL string
}
// ChatCompletionMessage is the JSON shape of a single chat message. It is
// used for request messages, for the message of a non-streamed response
// choice, and for the delta of a streamed response choice.
type ChatCompletionMessage struct {
// Role is the message author role (e.g. "assistant" or "tool").
Role string `json:"role"`
// Content is the message text; omitted from JSON when empty.
Content string `json:"content,omitempty"`
// ToolCalls lists the tool calls attached to the message; omitted when empty.
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// ToolCallID names the tool call a "tool" message is the result of; omitted
// when empty.
ToolCallID string `json:"tool_call_id,omitempty"`
}
// ToolCall is the JSON shape of a single tool invocation attached to a message.
type ToolCall struct {
// Type is the kind of tool call; this package always writes "function".
Type string `json:"type"`
// ID identifies the call so a later tool result can refer to it.
ID string `json:"id"`
// Index is the position of the call within a streamed response. It is a
// pointer so that an absent index can be told apart from index 0; the
// streaming code rejects streamed tool calls whose Index is nil.
Index *int `json:"index,omitempty"`
// Function carries the function name and its JSON-encoded arguments.
Function FunctionDefinition `json:"function"`
}
// FunctionDefinition does double duty: inside a Tool it declares a callable
// function (Name, Description, Parameters), and inside a ToolCall it describes
// one invocation (Name, Arguments).
//
// NOTE(review): Description and Parameters have no omitempty, so they are
// serialized (as empty values) even when this struct is used for a ToolCall —
// confirm the target API tolerates the extra fields.
type FunctionDefinition struct {
// Name is the function name.
Name string `json:"name"`
// Description is the human-readable description of the function.
Description string `json:"description"`
// Parameters is the schema of the function's arguments.
Parameters ToolParameters `json:"parameters"`
// Arguments is the JSON-encoded argument object of a call; omitted when empty.
Arguments string `json:"arguments,omitempty"`
}
// ToolParameters is the schema describing the arguments a function accepts.
type ToolParameters struct {
// Type is the schema type; convertTools always sets it to "object".
Type string `json:"type"`
// Properties maps each parameter name to its description; omitted when empty.
Properties map[string]ToolParameter `json:"properties,omitempty"`
// Required lists the names of mandatory parameters; omitted when empty.
Required []string `json:"required,omitempty"`
}
// ToolParameter describes a single named parameter of a function.
type ToolParameter struct {
// Type is the parameter's type name.
Type string `json:"type"`
// Description is the human-readable description of the parameter.
Description string `json:"description"`
// Enum optionally restricts the parameter to a fixed set of values; omitted
// when empty.
Enum []string `json:"enum,omitempty"`
}
// Tool is one entry of the "tools" array of a chat completion request.
type Tool struct {
// Type is the tool kind; convertTools always sets it to "function".
Type string `json:"type"`
// Function declares the function exposed by this tool.
Function FunctionDefinition `json:"function"`
}
// ChatCompletionRequest is the JSON body posted to the chat completions
// endpoint.
type ChatCompletionRequest struct {
// Model is the name of the model to use.
Model string `json:"model"`
// MaxTokens is omitted from the JSON body when zero.
MaxTokens int `json:"max_tokens,omitempty"`
// Temperature is omitted from the JSON body when zero, so an explicit
// temperature of 0 is never sent.
Temperature float32 `json:"temperature,omitempty"`
// Messages is the conversation so far.
Messages []ChatCompletionMessage `json:"messages"`
// N is the number of choices requested; createChatCompletionRequest sets it
// to 1.
N int `json:"n"`
// Tools lists the tools offered to the model; omitted when empty.
Tools []Tool `json:"tools,omitempty"`
// ToolChoice is set to "auto" when tools are offered; omitted when empty.
ToolChoice string `json:"tool_choice,omitempty"`
// Stream requests a streamed response; omitted when false.
Stream bool `json:"stream,omitempty"`
}
// ChatCompletionChoice is one completion alternative of a non-streamed
// response.
type ChatCompletionChoice struct {
// Message is the complete message produced for this choice.
Message ChatCompletionMessage `json:"message"`
}
// ChatCompletionResponse is the decoded body of a non-streamed chat completion
// response. Only the fields this package reads are declared.
type ChatCompletionResponse struct {
// Choices holds the returned choices; requests built by this package ask
// for exactly one.
Choices []ChatCompletionChoice `json:"choices"`
}
// ChatCompletionStreamChoice is one choice inside a single streamed chunk.
type ChatCompletionStreamChoice struct {
// Delta is the incremental piece of the message carried by this chunk.
Delta ChatCompletionMessage `json:"delta"`
}
// ChatCompletionStreamResponse is the decoded payload of one "data:" line of a
// streamed chat completion response.
type ChatCompletionStreamResponse struct {
// Choices holds the per-choice deltas carried by this chunk.
Choices []ChatCompletionStreamChoice `json:"choices"`
}
func convertTools(tools []api.ToolSpec) []Tool {
2024-04-29 00:14:21 -06:00
openaiTools := make([]Tool, len(tools))
for i, tool := range tools {
openaiTools[i].Type = "function"
2024-04-29 00:14:21 -06:00
params := make(map[string]ToolParameter)
var required []string
for _, param := range tool.Parameters {
2024-04-29 00:14:21 -06:00
params[param.Name] = ToolParameter{
Type: param.Type,
Description: param.Description,
Enum: param.Enum,
}
if param.Required {
required = append(required, param.Name)
}
}
2024-04-29 00:14:21 -06:00
openaiTools[i].Function = FunctionDefinition{
Name: tool.Name,
Description: tool.Description,
2024-04-29 00:14:21 -06:00
Parameters: ToolParameters{
Type: "object",
Properties: params,
Required: required,
},
}
}
return openaiTools
}
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
2024-04-29 00:14:21 -06:00
converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls {
converted[i].Type = "function"
converted[i].ID = call.ID
converted[i].Function.Name = call.Name
json, _ := json.Marshal(call.Parameters)
converted[i].Function.Arguments = string(json)
}
return converted
}
// convertToolCallToAPI converts wire-format tool calls back into api.ToolCall
// values: the ID and function name are copied and the JSON-encoded arguments
// string is decoded into the call's Parameters.
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
	out := make([]api.ToolCall, len(toolCalls))
	for i := range toolCalls {
		src, dst := &toolCalls[i], &out[i]
		dst.ID = src.ID
		dst.Name = src.Function.Name
		// Best effort: arguments that fail to decode are ignored and leave
		// Parameters as whatever Unmarshal managed to fill in.
		_ = json.Unmarshal([]byte(src.Function.Arguments), &dst.Parameters)
	}
	return out
}
func createChatCompletionRequest(
params api.RequestParameters,
messages []api.Message,
2024-04-29 00:14:21 -06:00
) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
for _, m := range messages {
switch m.Role {
case "tool_call":
2024-04-29 00:14:21 -06:00
message := ChatCompletionMessage{}
message.Role = "assistant"
message.Content = m.Content
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
requestMessages = append(requestMessages, message)
case "tool_result":
// expand tool_result messages' results into multiple openAI messages
for _, result := range m.ToolResults {
2024-04-29 00:14:21 -06:00
message := ChatCompletionMessage{}
message.Role = "tool"
message.Content = result.Result
message.ToolCallID = result.ToolCallID
requestMessages = append(requestMessages, message)
}
default:
2024-04-29 00:14:21 -06:00
message := ChatCompletionMessage{}
message.Role = string(m.Role)
message.Content = m.Content
requestMessages = append(requestMessages, message)
}
}
2024-04-29 00:14:21 -06:00
request := ChatCompletionRequest{
Model: params.Model,
MaxTokens: params.MaxTokens,
Temperature: params.Temperature,
Messages: requestMessages,
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
}
if len(params.ToolBag) > 0 {
request.Tools = convertTools(params.ToolBag)
request.ToolChoice = "auto"
}
return request
}
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
jsonData, err := json.Marshal(r)
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
2024-04-29 00:14:21 -06:00
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
2024-05-05 01:32:35 -06:00
if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes))
}
return resp, err
2024-04-29 00:14:21 -06:00
}
func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
2024-04-29 00:14:21 -06:00
req := createChatCompletionRequest(params, messages)
resp, err := c.sendRequest(ctx, req)
2024-04-29 00:14:21 -06:00
if err != nil {
return nil, err
2024-04-29 00:14:21 -06:00
}
defer resp.Body.Close()
var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil {
return nil, err
}
2024-04-29 00:14:21 -06:00
choice := completionResp.Choices[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content + choice.Message.Content
} else {
content = choice.Message.Content
}
toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
}
func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params api.RequestParameters,
messages []api.Message,
output chan<- api.Chunk,
) (*api.Message, error) {
if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages")
}
2024-04-29 00:14:21 -06:00
req := createChatCompletionRequest(params, messages)
req.Stream = true
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, err
}
2024-04-29 00:14:21 -06:00
defer resp.Body.Close()
content := strings.Builder{}
2024-04-29 00:14:21 -06:00
toolCalls := []ToolCall{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
2024-04-29 00:14:21 -06:00
reader := bufio.NewReader(resp.Body)
for {
2024-04-29 00:14:21 -06:00
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return nil, err
2024-04-29 00:14:21 -06:00
}
line = bytes.TrimSpace(line)
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
continue
}
2024-04-29 00:14:21 -06:00
line = bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(line, []byte("[DONE]")) {
break
}
2024-04-29 00:14:21 -06:00
var streamResp ChatCompletionStreamResponse
err = json.Unmarshal(line, &streamResp)
if err != nil {
return nil, err
2024-04-29 00:14:21 -06:00
}
delta := streamResp.Choices[0].Delta
if len(delta.ToolCalls) > 0 {
// Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls {
if tc.Index == nil {
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
}
if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc)
} else {
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
}
}
2024-04-29 00:14:21 -06:00
}
if len(delta.Content) > 0 {
output <- api.Chunk{
Content: delta.Content,
TokenCount: 1,
}
content.WriteString(delta.Content)
}
}
if len(toolCalls) > 0 {
return &api.Message{
Role: api.MessageRoleToolCall,
Content: content.String(),
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
}
return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
}