2024-02-21 21:55:38 -07:00
|
|
|
package openai
|
|
|
|
|
|
|
|
import (
|
2024-04-29 00:14:21 -06:00
|
|
|
"bufio"
|
|
|
|
"bytes"
|
2024-02-21 21:55:38 -07:00
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
2024-04-29 00:14:21 -06:00
|
|
|
"net/http"
|
2024-02-21 21:55:38 -07:00
|
|
|
"strings"
|
|
|
|
|
2024-06-09 10:42:53 -06:00
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
2024-02-21 21:55:38 -07:00
|
|
|
)
|
|
|
|
|
2024-06-22 19:48:31 -06:00
|
|
|
// OpenAIClient issues chat completion requests against the API rooted at
// BaseURL.
type OpenAIClient struct {
	// APIKey is sent as the bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root; "/chat/completions" is appended to it when
	// building request URLs.
	BaseURL string
}
|
|
|
|
|
|
|
|
type ChatCompletionMessage struct {
|
|
|
|
Role string `json:"role"`
|
|
|
|
Content string `json:"content,omitempty"`
|
|
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ToolCall struct {
|
|
|
|
Type string `json:"type"`
|
|
|
|
ID string `json:"id"`
|
|
|
|
Index *int `json:"index,omitempty"`
|
|
|
|
Function FunctionDefinition `json:"function"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type FunctionDefinition struct {
|
|
|
|
Name string `json:"name"`
|
|
|
|
Description string `json:"description"`
|
|
|
|
Parameters ToolParameters `json:"parameters"`
|
|
|
|
Arguments string `json:"arguments,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ToolParameters struct {
|
|
|
|
Type string `json:"type"`
|
|
|
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
|
|
Required []string `json:"required,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
// ToolParameter describes a single named parameter inside a tool's schema.
type ToolParameter struct {
	Type        string `json:"type"`
	Description string `json:"description"`
	// Enum, when non-empty, is serialized as the "enum" list; it is omitted
	// otherwise.
	Enum []string `json:"enum,omitempty"`
}
|
|
|
|
|
|
|
|
type Tool struct {
|
|
|
|
Type string `json:"type"`
|
|
|
|
Function FunctionDefinition `json:"function"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ChatCompletionRequest struct {
|
|
|
|
Model string `json:"model"`
|
|
|
|
MaxTokens int `json:"max_tokens,omitempty"`
|
|
|
|
Temperature float32 `json:"temperature,omitempty"`
|
|
|
|
Messages []ChatCompletionMessage `json:"messages"`
|
|
|
|
N int `json:"n"`
|
|
|
|
Tools []Tool `json:"tools,omitempty"`
|
|
|
|
ToolChoice string `json:"tool_choice,omitempty"`
|
|
|
|
Stream bool `json:"stream,omitempty"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ChatCompletionChoice struct {
|
|
|
|
Message ChatCompletionMessage `json:"message"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ChatCompletionResponse struct {
|
|
|
|
Choices []ChatCompletionChoice `json:"choices"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ChatCompletionStreamChoice struct {
|
|
|
|
Delta ChatCompletionMessage `json:"delta"`
|
|
|
|
}
|
|
|
|
|
|
|
|
type ChatCompletionStreamResponse struct {
|
|
|
|
Choices []ChatCompletionStreamChoice `json:"choices"`
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertTools(tools []api.ToolSpec) []Tool {
|
2024-04-29 00:14:21 -06:00
|
|
|
openaiTools := make([]Tool, len(tools))
|
2024-02-21 21:55:38 -07:00
|
|
|
for i, tool := range tools {
|
|
|
|
openaiTools[i].Type = "function"
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
params := make(map[string]ToolParameter)
|
2024-02-21 21:55:38 -07:00
|
|
|
var required []string
|
|
|
|
|
|
|
|
for _, param := range tool.Parameters {
|
2024-04-29 00:14:21 -06:00
|
|
|
params[param.Name] = ToolParameter{
|
2024-02-21 21:55:38 -07:00
|
|
|
Type: param.Type,
|
|
|
|
Description: param.Description,
|
|
|
|
Enum: param.Enum,
|
|
|
|
}
|
|
|
|
if param.Required {
|
|
|
|
required = append(required, param.Name)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
openaiTools[i].Function = FunctionDefinition{
|
2024-02-21 21:55:38 -07:00
|
|
|
Name: tool.Name,
|
|
|
|
Description: tool.Description,
|
2024-04-29 00:14:21 -06:00
|
|
|
Parameters: ToolParameters{
|
2024-02-21 21:55:38 -07:00
|
|
|
Type: "object",
|
|
|
|
Properties: params,
|
|
|
|
Required: required,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return openaiTools
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
2024-04-29 00:14:21 -06:00
|
|
|
converted := make([]ToolCall, len(toolCalls))
|
2024-02-21 21:55:38 -07:00
|
|
|
for i, call := range toolCalls {
|
|
|
|
converted[i].Type = "function"
|
|
|
|
converted[i].ID = call.ID
|
|
|
|
converted[i].Function.Name = call.Name
|
|
|
|
|
|
|
|
json, _ := json.Marshal(call.Parameters)
|
|
|
|
converted[i].Function.Arguments = string(json)
|
|
|
|
}
|
|
|
|
return converted
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
|
|
|
converted := make([]api.ToolCall, len(toolCalls))
|
2024-02-21 21:55:38 -07:00
|
|
|
for i, call := range toolCalls {
|
|
|
|
converted[i].ID = call.ID
|
|
|
|
converted[i].Name = call.Function.Name
|
|
|
|
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
|
|
|
}
|
|
|
|
return converted
|
|
|
|
}
|
|
|
|
|
|
|
|
func createChatCompletionRequest(
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
2024-04-29 00:14:21 -06:00
|
|
|
) ChatCompletionRequest {
|
|
|
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
2024-02-21 21:55:38 -07:00
|
|
|
|
|
|
|
for _, m := range messages {
|
|
|
|
switch m.Role {
|
|
|
|
case "tool_call":
|
2024-04-29 00:14:21 -06:00
|
|
|
message := ChatCompletionMessage{}
|
2024-02-21 21:55:38 -07:00
|
|
|
message.Role = "assistant"
|
|
|
|
message.Content = m.Content
|
|
|
|
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
|
|
|
requestMessages = append(requestMessages, message)
|
|
|
|
case "tool_result":
|
|
|
|
// expand tool_result messages' results into multiple openAI messages
|
|
|
|
for _, result := range m.ToolResults {
|
2024-04-29 00:14:21 -06:00
|
|
|
message := ChatCompletionMessage{}
|
2024-02-21 21:55:38 -07:00
|
|
|
message.Role = "tool"
|
|
|
|
message.Content = result.Result
|
|
|
|
message.ToolCallID = result.ToolCallID
|
|
|
|
requestMessages = append(requestMessages, message)
|
|
|
|
}
|
|
|
|
default:
|
2024-04-29 00:14:21 -06:00
|
|
|
message := ChatCompletionMessage{}
|
2024-02-21 21:55:38 -07:00
|
|
|
message.Role = string(m.Role)
|
|
|
|
message.Content = m.Content
|
|
|
|
requestMessages = append(requestMessages, message)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
request := ChatCompletionRequest{
|
2024-02-21 21:55:38 -07:00
|
|
|
Model: params.Model,
|
|
|
|
MaxTokens: params.MaxTokens,
|
|
|
|
Temperature: params.Temperature,
|
|
|
|
Messages: requestMessages,
|
|
|
|
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(params.ToolBag) > 0 {
|
|
|
|
request.Tools = convertTools(params.ToolBag)
|
|
|
|
request.ToolChoice = "auto"
|
|
|
|
}
|
|
|
|
|
|
|
|
return request
|
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) {
|
2024-04-29 00:14:21 -06:00
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
|
|
|
|
|
|
|
client := &http.Client{}
|
2024-06-12 02:35:07 -06:00
|
|
|
resp, err := client.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2024-05-05 01:32:35 -06:00
|
|
|
|
|
|
|
if resp.StatusCode != 200 {
|
|
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
2024-02-21 21:55:38 -07:00
|
|
|
func (c *OpenAIClient) CreateChatCompletion(
|
2024-03-12 12:24:05 -06:00
|
|
|
ctx context.Context,
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
|
|
|
) (*api.Message, error) {
|
2024-03-22 11:51:01 -06:00
|
|
|
if len(messages) == 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
2024-03-22 11:51:01 -06:00
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
req := createChatCompletionRequest(params, messages)
|
|
|
|
jsonData, err := json.Marshal(req)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
2024-04-29 00:14:21 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
resp, err := c.sendRequest(httpReq)
|
2024-04-29 00:14:21 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
var completionResp ChatCompletionResponse
|
|
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
choice := completionResp.Choices[0]
|
2024-02-21 21:55:38 -07:00
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
var content string
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
if lastMessage.Role.IsAssistant() {
|
|
|
|
content = lastMessage.Content + choice.Message.Content
|
|
|
|
} else {
|
|
|
|
content = choice.Message.Content
|
|
|
|
}
|
|
|
|
|
2024-02-21 21:55:38 -07:00
|
|
|
toolCalls := choice.Message.ToolCalls
|
|
|
|
if len(toolCalls) > 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleToolCall,
|
|
|
|
Content: content,
|
|
|
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
|
|
}, nil
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleAssistant,
|
|
|
|
Content: content,
|
|
|
|
}, nil
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
2024-03-12 12:24:05 -06:00
|
|
|
ctx context.Context,
|
2024-06-12 02:35:07 -06:00
|
|
|
params api.RequestParameters,
|
|
|
|
messages []api.Message,
|
2024-06-09 10:42:53 -06:00
|
|
|
output chan<- api.Chunk,
|
2024-06-12 02:35:07 -06:00
|
|
|
) (*api.Message, error) {
|
2024-03-22 11:51:01 -06:00
|
|
|
if len(messages) == 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
2024-03-22 11:51:01 -06:00
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
req := createChatCompletionRequest(params, messages)
|
|
|
|
req.Stream = true
|
|
|
|
|
|
|
|
jsonData, err := json.Marshal(req)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
2024-04-29 00:14:21 -06:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
resp, err := c.sendRequest(httpReq)
|
2024-02-21 21:55:38 -07:00
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
2024-04-29 00:14:21 -06:00
|
|
|
defer resp.Body.Close()
|
2024-02-21 21:55:38 -07:00
|
|
|
|
|
|
|
content := strings.Builder{}
|
2024-04-29 00:14:21 -06:00
|
|
|
toolCalls := []ToolCall{}
|
2024-02-21 21:55:38 -07:00
|
|
|
|
2024-03-22 11:51:01 -06:00
|
|
|
lastMessage := messages[len(messages)-1]
|
|
|
|
if lastMessage.Role.IsAssistant() {
|
|
|
|
content.WriteString(lastMessage.Content)
|
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
reader := bufio.NewReader(resp.Body)
|
2024-02-21 21:55:38 -07:00
|
|
|
for {
|
2024-04-29 00:14:21 -06:00
|
|
|
line, err := reader.ReadBytes('\n')
|
|
|
|
if err != nil {
|
|
|
|
if err == io.EOF {
|
|
|
|
break
|
|
|
|
}
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
line = bytes.TrimSpace(line)
|
|
|
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
|
|
continue
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
|
|
if bytes.Equal(line, []byte("[DONE]")) {
|
2024-02-21 21:55:38 -07:00
|
|
|
break
|
|
|
|
}
|
|
|
|
|
2024-04-29 00:14:21 -06:00
|
|
|
var streamResp ChatCompletionStreamResponse
|
|
|
|
err = json.Unmarshal(line, &streamResp)
|
|
|
|
if err != nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, err
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
|
|
|
|
delta := streamResp.Choices[0].Delta
|
2024-02-21 21:55:38 -07:00
|
|
|
if len(delta.ToolCalls) > 0 {
|
|
|
|
// Construct streamed tool_call arguments
|
|
|
|
for _, tc := range delta.ToolCalls {
|
|
|
|
if tc.Index == nil {
|
2024-06-12 02:35:07 -06:00
|
|
|
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
if len(toolCalls) <= *tc.Index {
|
|
|
|
toolCalls = append(toolCalls, tc)
|
|
|
|
} else {
|
|
|
|
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
|
|
|
}
|
|
|
|
}
|
2024-04-29 00:14:21 -06:00
|
|
|
}
|
|
|
|
if len(delta.Content) > 0 {
|
2024-06-09 14:45:18 -06:00
|
|
|
output <- api.Chunk{
|
|
|
|
Content: delta.Content,
|
|
|
|
TokenCount: 1,
|
2024-06-08 17:37:58 -06:00
|
|
|
}
|
2024-02-21 21:55:38 -07:00
|
|
|
content.WriteString(delta.Content)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(toolCalls) > 0 {
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleToolCall,
|
|
|
|
Content: content.String(),
|
|
|
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
|
|
}, nil
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|
|
|
|
|
2024-06-12 02:35:07 -06:00
|
|
|
return &api.Message{
|
|
|
|
Role: api.MessageRoleAssistant,
|
|
|
|
Content: content.String(),
|
|
|
|
}, nil
|
2024-02-21 21:55:38 -07:00
|
|
|
}
|