Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between the `conversation` and `api` representations of `Message`, the latter storing the minimum information required for interaction with LLM providers. Necessary conversion between the two happens when making LLM calls.
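The conversion itself lives outside this file; as a minimal sketch of the idea (the `conversation`-side fields and the `api.MessageRole` type name are assumptions, not shown in this diff):

    // Hypothetical sketch only, not part of this commit: flatten a richer
    // conversation-side message into the minimal api.Message sent to providers.
    func toAPIMessage(m conversation.Message) api.Message {
        return api.Message{
            Role:        api.MessageRole(m.Role), // assumed type name
            Content:     m.Content,
            ToolCalls:   m.ToolCalls,
            ToolResults: m.ToolResults,
        }
    }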
344 lines
8.6 KiB
Go
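// Package openai implements an lmcli provider for the OpenAI chat
// completions API and compatible endpoints.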
package openai

import (
    "bufio"
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "strings"

    "git.mlow.ca/mlow/lmcli/pkg/api"
    "git.mlow.ca/mlow/lmcli/pkg/provider"
)

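// OpenAIClient is a client for the OpenAI chat completions API, or any
// compatible endpoint selected via BaseURL.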
type OpenAIClient struct {
    APIKey  string
    BaseURL string
    Headers map[string]string
}

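// ChatCompletionMessage is a message in the OpenAI chat completions wire
// format, used for both requests and responses.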
type ChatCompletionMessage struct {
    Role       string     `json:"role"`
    Content    string     `json:"content,omitempty"`
    ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
    ToolCallID string     `json:"tool_call_id,omitempty"`
}

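// ToolCall is a tool invocation requested by the model. Index is set on
// streamed responses, where it identifies which in-progress call a chunk
// of arguments belongs to.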
type ToolCall struct {
    Type     string             `json:"type"`
    ID       string             `json:"id"`
    Index    *int               `json:"index,omitempty"`
    Function FunctionDefinition `json:"function"`
}

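// FunctionDefinition serves double duty: Parameters is populated when
// declaring a tool to the API, while Arguments carries the JSON-encoded
// argument string on tool calls returned by the model.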
type FunctionDefinition struct {
    Name        string         `json:"name"`
    Description string         `json:"description"`
    Parameters  ToolParameters `json:"parameters"`
    Arguments   string         `json:"arguments,omitempty"`
}

type ToolParameters struct {
    Type       string                   `json:"type"`
    Properties map[string]ToolParameter `json:"properties,omitempty"`
    Required   []string                 `json:"required,omitempty"`
}

type ToolParameter struct {
    Type        string   `json:"type"`
    Description string   `json:"description"`
    Enum        []string `json:"enum,omitempty"`
}

type Tool struct {
    Type     string             `json:"type"`
    Function FunctionDefinition `json:"function"`
}

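// ChatCompletionRequest is the request body for the chat completions
// endpoint.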
type ChatCompletionRequest struct {
    Model       string                  `json:"model"`
    MaxTokens   int                     `json:"max_tokens,omitempty"`
    Temperature float32                 `json:"temperature,omitempty"`
    Messages    []ChatCompletionMessage `json:"messages"`
    N           int                     `json:"n"`
    Tools       []Tool                  `json:"tools,omitempty"`
    ToolChoice  string                  `json:"tool_choice,omitempty"`
    Stream      bool                    `json:"stream,omitempty"`
}

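// Response envelopes: regular responses carry complete messages in their
// choices, while streamed responses carry incremental deltas.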
type ChatCompletionChoice struct {
    Message ChatCompletionMessage `json:"message"`
}

type ChatCompletionResponse struct {
    Choices []ChatCompletionChoice `json:"choices"`
}

type ChatCompletionStreamChoice struct {
    Delta ChatCompletionMessage `json:"delta"`
}

type ChatCompletionStreamResponse struct {
    Choices []ChatCompletionStreamChoice `json:"choices"`
}

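// convertTools converts lmcli tool specs to the OpenAI "function" tool
// format, gathering the names of required parameters into the JSON-schema
// style Required list.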
func convertTools(tools []api.ToolSpec) []Tool {
    openaiTools := make([]Tool, len(tools))
    for i, tool := range tools {
        openaiTools[i].Type = "function"

        params := make(map[string]ToolParameter)
        var required []string

        for _, param := range tool.Parameters {
            params[param.Name] = ToolParameter{
                Type:        param.Type,
                Description: param.Description,
                Enum:        param.Enum,
            }
            if param.Required {
                required = append(required, param.Name)
            }
        }

        openaiTools[i].Function = FunctionDefinition{
            Name:        tool.Name,
            Description: tool.Description,
            Parameters: ToolParameters{
                Type:       "object",
                Properties: params,
                Required:   required,
            },
        }
    }
    return openaiTools
}

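// convertToolCallToOpenAI converts tool calls from their api
// representation to the OpenAI one, marshaling parameters into a JSON
// string.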
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
    converted := make([]ToolCall, len(toolCalls))
    for i, call := range toolCalls {
        converted[i].Type = "function"
        converted[i].ID = call.ID
        converted[i].Function.Name = call.Name

        // named args rather than json to avoid shadowing the json package
        args, _ := json.Marshal(call.Parameters)
        converted[i].Function.Arguments = string(args)
    }
    return converted
}

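// convertToolCallToAPI converts OpenAI tool calls back to their api
// representation, parsing the JSON argument string.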
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
    converted := make([]api.ToolCall, len(toolCalls))
    for i, call := range toolCalls {
        converted[i].ID = call.ID
        converted[i].Name = call.Function.Name
        // best-effort: malformed arguments leave Parameters empty
        json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
    }
    return converted
}

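// createChatCompletionRequest maps api messages onto the OpenAI wire
// format. Two roles need translation: "tool_call" becomes an assistant
// message with tool_calls set, and each result in a "tool_result" message
// is expanded into its own "tool"-role message.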
func createChatCompletionRequest(
    params provider.RequestParameters,
    messages []api.Message,
) ChatCompletionRequest {
    requestMessages := make([]ChatCompletionMessage, 0, len(messages))

    for _, m := range messages {
        switch m.Role {
        case "tool_call":
            message := ChatCompletionMessage{}
            message.Role = "assistant"
            message.Content = m.Content
            message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
            requestMessages = append(requestMessages, message)
        case "tool_result":
            // expand tool_result messages' results into multiple openAI messages
            for _, result := range m.ToolResults {
                message := ChatCompletionMessage{}
                message.Role = "tool"
                message.Content = result.Result
                message.ToolCallID = result.ToolCallID
                requestMessages = append(requestMessages, message)
            }
        default:
            message := ChatCompletionMessage{}
            message.Role = string(m.Role)
            message.Content = m.Content
            requestMessages = append(requestMessages, message)
        }
    }

    request := ChatCompletionRequest{
        Model:       params.Model,
        MaxTokens:   params.MaxTokens,
        Temperature: params.Temperature,
        Messages:    requestMessages,
        N:           1, // limit responses to 1 "choice". we use choices[0] to reference it
    }

    if len(params.Toolbox) > 0 {
        request.Tools = convertTools(params.Toolbox)
        request.ToolChoice = "auto"
    }

    return request
}

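// sendRequest marshals the request, POSTs it to the chat completions
// endpoint with auth and any custom headers applied, and returns an error
// containing the response body on a non-OK status.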
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
    jsonData, err := json.Marshal(r)
    if err != nil {
        return nil, err
    }

    req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
    if err != nil {
        return nil, err
    }

    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", "Bearer "+c.APIKey)
    for header, val := range c.Headers {
        req.Header.Set(header, val)
    }

    client := &http.Client{}
    resp, err := client.Do(req)
    if err != nil {
        return nil, err
    }

    if resp.StatusCode != http.StatusOK {
        // named body rather than bytes to avoid shadowing the bytes package
        body, _ := io.ReadAll(resp.Body)
        return resp, fmt.Errorf("%v", string(body))
    }

    return resp, nil
}

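// CreateChatCompletion requests a completion for the given messages and
// returns the assembled response message. When the last input message is
// an assistant message, the response is treated as its continuation and
// appended to it.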
func (c *OpenAIClient) CreateChatCompletion(
    ctx context.Context,
    params provider.RequestParameters,
    messages []api.Message,
) (*api.Message, error) {
    if len(messages) == 0 {
        return nil, fmt.Errorf("can't create completion from no messages")
    }

    req := createChatCompletionRequest(params, messages)

    resp, err := c.sendRequest(ctx, req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    var completionResp ChatCompletionResponse
    err = json.NewDecoder(resp.Body).Decode(&completionResp)
    if err != nil {
        return nil, err
    }

    // guard against an empty choices array rather than panicking
    if len(completionResp.Choices) == 0 {
        return nil, fmt.Errorf("no choices in response")
    }
    choice := completionResp.Choices[0]

    var content string
    lastMessage := messages[len(messages)-1]
    if lastMessage.Role.IsAssistant() {
        // continuation of a partial assistant message
        content = lastMessage.Content + choice.Message.Content
    } else {
        content = choice.Message.Content
    }

    toolCalls := choice.Message.ToolCalls
    if len(toolCalls) > 0 {
        return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
    }

    return api.NewMessageWithAssistant(content), nil
}

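// CreateChatCompletionStream is the streaming variant of
// CreateChatCompletion. It reads the response as server-sent events,
// emitting content deltas on output as they arrive and stitching together
// streamed tool call arguments by index before returning the complete
// message.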
func (c *OpenAIClient) CreateChatCompletionStream(
    ctx context.Context,
    params provider.RequestParameters,
    messages []api.Message,
    output chan<- provider.Chunk,
) (*api.Message, error) {
    if len(messages) == 0 {
        return nil, fmt.Errorf("can't create completion from no messages")
    }

    req := createChatCompletionRequest(params, messages)
    req.Stream = true

    resp, err := c.sendRequest(ctx, req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    content := strings.Builder{}
    toolCalls := []ToolCall{}

    lastMessage := messages[len(messages)-1]
    if lastMessage.Role.IsAssistant() {
        // continuation of a partial assistant message
        content.WriteString(lastMessage.Content)
    }

    reader := bufio.NewReader(resp.Body)
    for {
        line, err := reader.ReadBytes('\n')
        if err != nil {
            if err == io.EOF {
                break
            }
            return nil, err
        }

        line = bytes.TrimSpace(line)
        if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
            continue
        }

        line = bytes.TrimPrefix(line, []byte("data: "))
        if bytes.Equal(line, []byte("[DONE]")) {
            break
        }

        var streamResp ChatCompletionStreamResponse
        err = json.Unmarshal(line, &streamResp)
        if err != nil {
            return nil, err
        }

        // skip chunks with no choices rather than panicking on Choices[0]
        if len(streamResp.Choices) == 0 {
            continue
        }

        delta := streamResp.Choices[0].Delta
        if len(delta.ToolCalls) > 0 {
            // Construct streamed tool_call arguments
            for _, tc := range delta.ToolCalls {
                if tc.Index == nil {
                    return nil, fmt.Errorf("unexpected nil index for streamed tool call")
                }
                if len(toolCalls) <= *tc.Index {
                    toolCalls = append(toolCalls, tc)
                } else {
                    toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
                }
            }
        }
        if len(delta.Content) > 0 {
            output <- provider.Chunk{
                Content:    delta.Content,
                TokenCount: 1,
            }
            content.WriteString(delta.Content)
        }
    }

    if len(toolCalls) > 0 {
        return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
    }

    return api.NewMessageWithAssistant(content.String()), nil
}
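A rough usage sketch of the non-streaming client (hypothetical: the model name, message literal, and os/log usage are assumptions; only the client fields and the CreateChatCompletion signature come from the file above):

    ctx := context.Background()
    client := &OpenAIClient{
        APIKey:  os.Getenv("OPENAI_API_KEY"),
        BaseURL: "https://api.openai.com", // sendRequest appends /v1/chat/completions
    }
    msg, err := client.CreateChatCompletion(ctx, provider.RequestParameters{
        Model:     "gpt-4o", // example model name
        MaxTokens: 1024,
    }, []api.Message{
        {Role: "user", Content: "Hello!"}, // assumes api.MessageRole is a string type
    })
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(msg.Content)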