diff --git a/pkg/lmcli/tools/dir_tree.go b/pkg/agent/toolbox/dir_tree.go similarity index 87% rename from pkg/lmcli/tools/dir_tree.go rename to pkg/agent/toolbox/dir_tree.go index 7d7d8d9..243d379 100644 --- a/pkg/lmcli/tools/dir_tree.go +++ b/pkg/agent/toolbox/dir_tree.go @@ -1,4 +1,4 @@ -package tools +package toolbox import ( "fmt" @@ -7,8 +7,8 @@ import ( "strconv" "strings" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents. @@ -27,10 +27,10 @@ Example result: } ` -var DirTreeTool = model.Tool{ +var DirTreeTool = api.ToolSpec{ Name: "dir_tree", Description: TREE_DESCRIPTION, - Parameters: []model.ToolParameter{ + Parameters: []api.ToolParameter{ { Name: "relative_path", Type: "string", @@ -42,7 +42,7 @@ var DirTreeTool = model.Tool{ Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.", }, }, - Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { var relativeDir string if tmp, ok := args["relative_path"]; ok { relativeDir, ok = tmp.(string) @@ -76,25 +76,25 @@ var DirTreeTool = model.Tool{ }, } -func tree(path string, depth int) model.CallResult { +func tree(path string, depth int) api.CallResult { if path == "" { path = "." } ok, reason := toolutil.IsPathWithinCWD(path) if !ok { - return model.CallResult{Message: reason} + return api.CallResult{Message: reason} } var treeOutput strings.Builder treeOutput.WriteString(path + "\n") err := buildTree(&treeOutput, path, "", depth) if err != nil { - return model.CallResult{ + return api.CallResult{ Message: err.Error(), } } - return model.CallResult{Result: treeOutput.String()} + return api.CallResult{Result: treeOutput.String()} } func buildTree(output *strings.Builder, path string, prefix string, depth int) error { diff --git a/pkg/lmcli/tools/file_insert_lines.go b/pkg/agent/toolbox/file_insert_lines.go similarity index 76% rename from pkg/lmcli/tools/file_insert_lines.go rename to pkg/agent/toolbox/file_insert_lines.go index 513f9a5..17a197a 100644 --- a/pkg/lmcli/tools/file_insert_lines.go +++ b/pkg/agent/toolbox/file_insert_lines.go @@ -1,22 +1,22 @@ -package tools +package toolbox import ( "fmt" "os" "strings" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. 
Make sure your inserts match the flow and indentation of surrounding content.` -var FileInsertLinesTool = model.Tool{ +var FileInsertLinesTool = api.ToolSpec{ Name: "file_insert_lines", Description: FILE_INSERT_LINES_DESCRIPTION, - Parameters: []model.ToolParameter{ + Parameters: []api.ToolParameter{ { Name: "path", Type: "string", @@ -36,7 +36,7 @@ var FileInsertLinesTool = model.Tool{ Required: true, }, }, - Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { tmp, ok := args["path"] if !ok { return "", fmt.Errorf("path parameter to write_file was not included.") @@ -72,27 +72,27 @@ var FileInsertLinesTool = model.Tool{ }, } -func fileInsertLines(path string, position int, content string) model.CallResult { +func fileInsertLines(path string, position int, content string) api.CallResult { ok, reason := toolutil.IsPathWithinCWD(path) if !ok { - return model.CallResult{Message: reason} + return api.CallResult{Message: reason} } // Read the existing file's content data, err := os.ReadFile(path) if err != nil { if !os.IsNotExist(err) { - return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} } _, err = os.Create(path) if err != nil { - return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} } data = []byte{} } if position < 1 { - return model.CallResult{Message: "start_line cannot be less than 1"} + return api.CallResult{Message: "start_line cannot be less than 1"} } lines := strings.Split(string(data), "\n") @@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) model.CallResult // Join the lines and write back to the file err = os.WriteFile(path, []byte(newContent), 0644) if err != nil { - return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} } - return model.CallResult{Result: newContent} + return api.CallResult{Result: newContent} } diff --git a/pkg/lmcli/tools/file_replace_lines.go b/pkg/agent/toolbox/file_replace_lines.go similarity index 80% rename from pkg/lmcli/tools/file_replace_lines.go rename to pkg/agent/toolbox/file_replace_lines.go index cdb1def..ad346bb 100644 --- a/pkg/lmcli/tools/file_replace_lines.go +++ b/pkg/agent/toolbox/file_replace_lines.go @@ -1,12 +1,12 @@ -package tools +package toolbox import ( "fmt" "os" "strings" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. @@ -15,10 +15,10 @@ Useful for re-writing snippets/blocks of code or entire functions. Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.` -var FileReplaceLinesTool = model.Tool{ +var FileReplaceLinesTool = api.ToolSpec{ Name: "file_replace_lines", Description: FILE_REPLACE_LINES_DESCRIPTION, - Parameters: []model.ToolParameter{ + Parameters: []api.ToolParameter{ { Name: "path", Type: "string", @@ -42,7 +42,7 @@ var FileReplaceLinesTool = model.Tool{ Description: `Content to replace specified range. 
Omit to remove the specified range.`, }, }, - Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { tmp, ok := args["path"] if !ok { return "", fmt.Errorf("path parameter to write_file was not included.") @@ -87,27 +87,27 @@ var FileReplaceLinesTool = model.Tool{ }, } -func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult { +func fileReplaceLines(path string, startLine int, endLine int, content string) api.CallResult { ok, reason := toolutil.IsPathWithinCWD(path) if !ok { - return model.CallResult{Message: reason} + return api.CallResult{Message: reason} } // Read the existing file's content data, err := os.ReadFile(path) if err != nil { if !os.IsNotExist(err) { - return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} } _, err = os.Create(path) if err != nil { - return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} } data = []byte{} } if startLine < 1 { - return model.CallResult{Message: "start_line cannot be less than 1"} + return api.CallResult{Message: "start_line cannot be less than 1"} } lines := strings.Split(string(data), "\n") @@ -126,8 +126,8 @@ func fileReplaceLines(path string, startLine int, endLine int, content string) m // Join the lines and write back to the file err = os.WriteFile(path, []byte(newContent), 0644) if err != nil { - return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} } - return model.CallResult{Result: newContent} + return api.CallResult{Result: newContent} } diff --git a/pkg/lmcli/tools/read_dir.go b/pkg/agent/toolbox/read_dir.go similarity index 80% rename from pkg/lmcli/tools/read_dir.go rename to pkg/agent/toolbox/read_dir.go index 46534e4..02a2cfb 100644 --- a/pkg/lmcli/tools/read_dir.go +++ b/pkg/agent/toolbox/read_dir.go @@ -1,4 +1,4 @@ -package tools +package toolbox import ( "fmt" @@ -6,8 +6,8 @@ import ( "path/filepath" "strings" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). @@ -25,17 +25,17 @@ Example result: For files, size represents the size of the file, in bytes. For directories, size represents the number of entries in that directory.` -var ReadDirTool = model.Tool{ +var ReadDirTool = api.ToolSpec{ Name: "read_dir", Description: READ_DIR_DESCRIPTION, - Parameters: []model.ToolParameter{ + Parameters: []api.ToolParameter{ { Name: "relative_dir", Type: "string", Description: "If set, read the contents of a directory relative to the current one.", }, }, - Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { var relativeDir string tmp, ok := args["relative_dir"] if ok { @@ -53,18 +53,18 @@ var ReadDirTool = model.Tool{ }, } -func readDir(path string) model.CallResult { +func readDir(path string) api.CallResult { if path == "" { path = "." 
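		// no path given; default to the current working directory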
} ok, reason := toolutil.IsPathWithinCWD(path) if !ok { - return model.CallResult{Message: reason} + return api.CallResult{Message: reason} } files, err := os.ReadDir(path) if err != nil { - return model.CallResult{ + return api.CallResult{ Message: err.Error(), } } @@ -96,5 +96,5 @@ func readDir(path string) model.CallResult { }) } - return model.CallResult{Result: dirContents} + return api.CallResult{Result: dirContents} } diff --git a/pkg/lmcli/tools/read_file.go b/pkg/agent/toolbox/read_file.go similarity index 75% rename from pkg/lmcli/tools/read_file.go rename to pkg/agent/toolbox/read_file.go index 8164dfa..a35eedb 100644 --- a/pkg/lmcli/tools/read_file.go +++ b/pkg/agent/toolbox/read_file.go @@ -1,12 +1,12 @@ -package tools +package toolbox import ( "fmt" "os" "strings" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory. @@ -21,10 +21,10 @@ Example result: "result": "1\tthe contents\n2\tof the file\n" }` -var ReadFileTool = model.Tool{ +var ReadFileTool = api.ToolSpec{ Name: "read_file", Description: READ_FILE_DESCRIPTION, - Parameters: []model.ToolParameter{ + Parameters: []api.ToolParameter{ { Name: "path", Type: "string", @@ -33,7 +33,7 @@ var ReadFileTool = model.Tool{ }, }, - Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { + Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { tmp, ok := args["path"] if !ok { return "", fmt.Errorf("Path parameter to read_file was not included.") @@ -51,14 +51,14 @@ var ReadFileTool = model.Tool{ }, } -func readFile(path string) model.CallResult { +func readFile(path string) api.CallResult { ok, reason := toolutil.IsPathWithinCWD(path) if !ok { - return model.CallResult{Message: reason} + return api.CallResult{Message: reason} } data, err := os.ReadFile(path) if err != nil { - return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} + return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} } lines := strings.Split(string(data), "\n") @@ -67,7 +67,7 @@ func readFile(path string) model.CallResult { content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) } - return model.CallResult{ + return api.CallResult{ Result: content.String(), } } diff --git a/pkg/lmcli/tools/util/util.go b/pkg/agent/toolbox/util/util.go similarity index 100% rename from pkg/lmcli/tools/util/util.go rename to pkg/agent/toolbox/util/util.go diff --git a/pkg/lmcli/tools/write_file.go b/pkg/agent/toolbox/write_file.go similarity index 74% rename from pkg/lmcli/tools/write_file.go rename to pkg/agent/toolbox/write_file.go index 7263db5..5f701a7 100644 --- a/pkg/lmcli/tools/write_file.go +++ b/pkg/agent/toolbox/write_file.go @@ -1,11 +1,11 @@ -package tools +package toolbox import ( "fmt" "os" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" + toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. 
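Every tool in the relocated toolbox package follows the pattern above: a package-level api.ToolSpec value whose Impl field does the actual work. A minimal sketch of driving one of these specs directly, outside any provider (hypothetical test-style usage; the test name and the path argument are illustrative, not part of this change):

package toolbox_test

import (
	"testing"

	"git.mlow.ca/mlow/lmcli/pkg/agent/toolbox"
)

// Hypothetical smoke test: dispatch through the spec's Impl field, the same
// entry point the agent package (introduced below) uses for model-requested calls.
func TestReadFileToolImpl(t *testing.T) {
	tool := toolbox.ReadFileTool
	result, err := tool.Impl(&tool, map[string]interface{}{
		"path": "util/util.go", // must resolve inside the CWD sandbox
	})
	if err != nil {
		t.Fatalf("read_file returned an error: %v", err)
	}
	if result == "" {
		t.Error("expected line-numbered file contents, got an empty result")
	}
}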
@@ -15,10 +15,10 @@ Example result:
   "message": "success"
 }`
 
-var WriteFileTool = model.Tool{
+var WriteFileTool = api.ToolSpec{
 	Name:        "write_file",
 	Description: WRITE_FILE_DESCRIPTION,
-	Parameters: []model.ToolParameter{
+	Parameters: []api.ToolParameter{
 		{
 			Name:        "path",
 			Type:        "string",
@@ -32,7 +32,7 @@ var WriteFileTool = model.Tool{
 			Required:    true,
 		},
 	},
-	Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
+	Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
 		tmp, ok := args["path"]
 		if !ok {
 			return "", fmt.Errorf("Path parameter to write_file was not included.")
@@ -58,14 +58,14 @@ var WriteFileTool = model.Tool{
 	},
 }
 
-func writeFile(path string, content string) model.CallResult {
+func writeFile(path string, content string) api.CallResult {
 	ok, reason := toolutil.IsPathWithinCWD(path)
 	if !ok {
-		return model.CallResult{Message: reason}
+		return api.CallResult{Message: reason}
 	}
 	err := os.WriteFile(path, []byte(content), 0644)
 	if err != nil {
-		return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
+		return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
 	}
-	return model.CallResult{}
+	return api.CallResult{}
 }
diff --git a/pkg/agent/tools.go b/pkg/agent/tools.go
new file mode 100644
index 0000000..d920b23
--- /dev/null
+++ b/pkg/agent/tools.go
@@ -0,0 +1,48 @@
+package agent
+
+import (
+	"fmt"
+
+	"git.mlow.ca/mlow/lmcli/pkg/agent/toolbox"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
+)
+
+var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
+	"dir_tree":           toolbox.DirTreeTool,
+	"read_dir":           toolbox.ReadDirTool,
+	"read_file":          toolbox.ReadFileTool,
+	"write_file":         toolbox.WriteFileTool,
+	"file_insert_lines":  toolbox.FileInsertLinesTool,
+	"file_replace_lines": toolbox.FileReplaceLinesTool,
+}
+
+func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
+	var toolResults []api.ToolResult
+	for _, call := range calls {
+		var tool *api.ToolSpec
+		for i := range available {
+			if available[i].Name == call.Name {
+				tool = &available[i]
+				break
+			}
+		}
+		if tool == nil {
+			return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
+		}
+
+		// Execute the tool
+		result, err := tool.Impl(tool, call.Parameters)
+		if err != nil {
+			return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
+		}
+
+		toolResult := api.ToolResult{
+			ToolCallID: call.ID,
+			ToolName:   call.Name,
+			Result:     result,
+		}
+
+		toolResults = append(toolResults, toolResult)
+	}
+	return toolResults, nil
+}
diff --git a/pkg/api/api.go b/pkg/api/api.go
index 41f52fc..ea34ebb 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -2,35 +2,41 @@ package api
 
 import (
 	"context"
-
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
 
-type ReplyCallback func(model.Message)
+type ReplyCallback func(Message)
 
 type Chunk struct {
 	Content    string
 	TokenCount uint
}
 
-type ChatCompletionClient interface {
+type RequestParameters struct {
+	Model string
+
+	MaxTokens   int
+	Temperature float32
+	TopP        float32
+
+	ToolBag []ToolSpec
+}
+
+type ChatCompletionProvider interface {
 	// CreateChatCompletion requests a response to the provided messages.
-	// Replies are appended to the given replies struct, and the
-	// complete user-facing response is returned as a string.
+	// The model's complete reply is returned as a Message.
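With tool execution removed from the providers (see the handleToolCalls deletions further down), a reply with Role == MessageRoleToolCall is now handed back to the caller, which is expected to run the tools and re-prompt. A minimal sketch of that driver loop against the interface defined here (hypothetical; the converse function and its package are not part of this change, and error handling is kept terse):

package chat

import (
	"context"

	"git.mlow.ca/mlow/lmcli/pkg/agent"
	"git.mlow.ca/mlow/lmcli/pkg/api"
)

// converse drives a conversation to completion, executing any tool calls
// the model makes and feeding the results back until a plain assistant
// reply is produced.
func converse(
	ctx context.Context,
	provider api.ChatCompletionProvider,
	params api.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	for {
		reply, err := provider.CreateChatCompletion(ctx, params, messages)
		if err != nil {
			return nil, err
		}
		if reply.Role != api.MessageRoleToolCall {
			return reply, nil
		}
		// Run the requested tools, then re-prompt with the results appended.
		results, err := agent.ExecuteToolCalls(reply.ToolCalls, params.ToolBag)
		if err != nil {
			return nil, err
		}
		messages = append(messages, *reply, api.Message{
			Role:        api.MessageRoleToolResult,
			ToolResults: results,
		})
	}
}

The streaming variant is symmetric: pass a chunk channel and handle the returned Message the same way.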
 	CreateChatCompletion(
 		ctx context.Context,
-		params model.RequestParameters,
-		messages []model.Message,
-		callback ReplyCallback,
-	) (string, error)
+		params RequestParameters,
+		messages []Message,
+	) (*Message, error)
 
-	// Like CreateChageCompletion, except the response is streamed via
-	// the output channel as it's received.
+	// Like CreateChatCompletion, except the response is streamed via
+	// the provided channel as it's received.
 	CreateChatCompletionStream(
 		ctx context.Context,
-		params model.RequestParameters,
-		messages []model.Message,
-		callback ReplyCallback,
-		output chan<- Chunk,
-	) (string, error)
+		params RequestParameters,
+		messages []Message,
+		chunks chan<- Chunk,
+	) (*Message, error)
 }
diff --git a/pkg/api/conversation.go b/pkg/api/conversation.go
new file mode 100644
index 0000000..1ee1064
--- /dev/null
+++ b/pkg/api/conversation.go
@@ -0,0 +1,11 @@
+package api
+
+import "database/sql"
+
+type Conversation struct {
+	ID             uint `gorm:"primaryKey"`
+	ShortName      sql.NullString
+	Title          string
+	SelectedRootID *uint
+	SelectedRoot   *Message `gorm:"foreignKey:SelectedRootID"`
+}
diff --git a/pkg/lmcli/model/conversation.go b/pkg/api/message.go
similarity index 79%
rename from pkg/lmcli/model/conversation.go
rename to pkg/api/message.go
index c817706..cc7cec1 100644
--- a/pkg/lmcli/model/conversation.go
+++ b/pkg/api/message.go
@@ -1,7 +1,6 @@
-package model
+package api
 
 import (
-	"database/sql"
 	"time"
 )
 
@@ -32,24 +31,6 @@ type Message struct {
 	SelectedReply   *Message `gorm:"foreignKey:SelectedReplyID"`
 }
 
-type Conversation struct {
-	ID             uint `gorm:"primaryKey"`
-	ShortName      sql.NullString
-	Title          string
-	SelectedRootID *uint
-	SelectedRoot   *Message `gorm:"foreignKey:SelectedRootID"`
-}
-
-type RequestParameters struct {
-	Model string
-
-	MaxTokens   int
-	Temperature float32
-	TopP        float32
-
-	ToolBag []Tool
-}
-
 func (m *MessageRole) IsAssistant() bool {
 	switch *m {
 	case MessageRoleAssistant, MessageRoleToolCall:
diff --git a/pkg/api/provider/anthropic/anthropic.go b/pkg/api/provider/anthropic/anthropic.go
index a7426da..b70cf87 100644
--- a/pkg/api/provider/anthropic/anthropic.go
+++ b/pkg/api/provider/anthropic/anthropic.go
@@ -11,11 +11,9 @@ import (
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/api"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )
 
-func buildRequest(params model.RequestParameters, messages []model.Message) Request {
+func buildRequest(params api.RequestParameters, messages []api.Message) Request {
 	requestBody := Request{
 		Model:    params.Model,
 		Messages: make([]Message, len(messages)),
@@ -30,7 +28,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 	}
 
 	startIdx := 0
-	if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
+	if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
 		requestBody.System = messages[0].Content
 		requestBody.Messages = requestBody.Messages[1:]
 		startIdx = 1
@@ -48,7 +46,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 		message := &requestBody.Messages[i]
 
 		switch msg.Role {
-		case model.MessageRoleToolCall:
+		case api.MessageRoleToolCall:
 			message.Role = "assistant"
 			if msg.Content != "" {
 				message.Content = msg.Content
@@ -63,7 +61,7 @@
 			} else {
 				message.Content = xmlString
 			}
-		case model.MessageRoleToolResult:
+		case api.MessageRoleToolResult:
 			xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
 			xmlString, err := xmlFuncResults.XMLString()
 			if err != nil {
@@ -105,26 +103,25 @@ func sendRequest(ctx context.Context, c
*AnthropicClient, r Request) (*http.Resp func (c *AnthropicClient) CreateChatCompletion( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, -) (string, error) { + params api.RequestParameters, + messages []api.Message, +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } request := buildRequest(params, messages) resp, err := sendRequest(ctx, c, request) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() var response Response err = json.NewDecoder(resp.Body).Decode(&response) if err != nil { - return "", fmt.Errorf("failed to decode response: %v", err) + return nil, fmt.Errorf("failed to decode response: %v", err) } sb := strings.Builder{} @@ -137,34 +134,28 @@ func (c *AnthropicClient) CreateChatCompletion( } for _, content := range response.Content { - var reply model.Message switch content.Type { case "text": - reply = model.Message{ - Role: model.MessageRoleAssistant, - Content: content.Text, - } - sb.WriteString(reply.Content) + sb.WriteString(content.Text) default: - return "", fmt.Errorf("unsupported message type: %s", content.Type) - } - if callback != nil { - callback(reply) + return nil, fmt.Errorf("unsupported message type: %s", content.Type) } } - return sb.String(), nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: sb.String(), + }, nil } func (c *AnthropicClient) CreateChatCompletionStream( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, + params api.RequestParameters, + messages []api.Message, output chan<- api.Chunk, -) (string, error) { +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } request := buildRequest(params, messages) @@ -172,19 +163,18 @@ func (c *AnthropicClient) CreateChatCompletionStream( resp, err := sendRequest(ctx, c, request) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() sb := strings.Builder{} lastMessage := messages[len(messages)-1] - continuation := false if messages[len(messages)-1].Role.IsAssistant() { // this is a continuation of a previous assistant reply, so we'll // include its contents in the final result + // TODO: handle this at higher level sb.WriteString(lastMessage.Content) - continuation = true } scanner := bufio.NewScanner(resp.Body) @@ -200,29 +190,29 @@ func (c *AnthropicClient) CreateChatCompletionStream( var event map[string]interface{} err := json.Unmarshal([]byte(line), &event) if err != nil { - return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) + return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) } eventType, ok := event["type"].(string) if !ok { - return "", fmt.Errorf("invalid event: %s", line) + return nil, fmt.Errorf("invalid event: %s", line) } switch eventType { case "error": - return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) + return nil, fmt.Errorf("an error occurred: %s", event["error"]) default: - return sb.String(), fmt.Errorf("unknown event type: %s", eventType) + return nil, fmt.Errorf("unknown event type: %s", eventType) } } else if strings.HasPrefix(line, "data:") { data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) var event map[string]interface{} err := 
json.Unmarshal([]byte(data), &event) if err != nil { - return "", fmt.Errorf("failed to unmarshal event data: %v", err) + return nil, fmt.Errorf("failed to unmarshal event data: %v", err) } eventType, ok := event["type"].(string) if !ok { - return "", fmt.Errorf("invalid event type") + return nil, fmt.Errorf("invalid event type") } switch eventType { @@ -235,15 +225,15 @@ func (c *AnthropicClient) CreateChatCompletionStream( case "content_block_delta": delta, ok := event["delta"].(map[string]interface{}) if !ok { - return "", fmt.Errorf("invalid content block delta") + return nil, fmt.Errorf("invalid content block delta") } text, ok := delta["text"].(string) if !ok { - return "", fmt.Errorf("invalid text delta") + return nil, fmt.Errorf("invalid text delta") } sb.WriteString(text) output <- api.Chunk{ - Content: text, + Content: text, TokenCount: 1, } case "content_block_stop": @@ -251,7 +241,7 @@ func (c *AnthropicClient) CreateChatCompletionStream( case "message_delta": delta, ok := event["delta"].(map[string]interface{}) if !ok { - return "", fmt.Errorf("invalid message delta") + return nil, fmt.Errorf("invalid message delta") } stopReason, ok := delta["stop_reason"].(string) if ok && stopReason == "stop_sequence" { @@ -261,67 +251,39 @@ func (c *AnthropicClient) CreateChatCompletionStream( start := strings.Index(content, "") if start == -1 { - return content, fmt.Errorf("reached stop sequence but no opening tag found") + return nil, fmt.Errorf("reached stop sequence but no opening tag found") } sb.WriteString(FUNCTION_STOP_SEQUENCE) output <- api.Chunk{ - Content: FUNCTION_STOP_SEQUENCE, + Content: FUNCTION_STOP_SEQUENCE, TokenCount: 1, } - funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE var functionCalls XMLFunctionCalls err := xml.Unmarshal([]byte(funcCallXml), &functionCalls) if err != nil { - return "", fmt.Errorf("failed to unmarshal function_calls: %v", err) + return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err) } - toolCall := model.Message{ - Role: model.MessageRoleToolCall, + return &api.Message{ + Role: api.MessageRoleToolCall, // function call xml stripped from content for model interop Content: strings.TrimSpace(content[:start]), ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), - } - - toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) - if err != nil { - return "", err - } - - toolResult := model.Message{ - Role: model.MessageRoleToolResult, - ToolResults: toolResults, - } - - if callback != nil { - callback(toolCall) - callback(toolResult) - } - - if continuation { - messages[len(messages)-1] = toolCall - } else { - messages = append(messages, toolCall) - } - - messages = append(messages, toolResult) - return c.CreateChatCompletionStream(ctx, params, messages, callback, output) + }, nil } } case "message_stop": // return the completed message content := sb.String() - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content, - }) - } - return content, nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content, + }, nil case "error": - return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) + return nil, fmt.Errorf("an error occurred: %s", event["error"]) default: fmt.Printf("\nUnrecognized event: %s\n", data) } @@ -329,8 +291,8 @@ func (c *AnthropicClient) CreateChatCompletionStream( } if err := scanner.Err(); err != nil { - return "", fmt.Errorf("failed to read response body: %v", err) + return nil, fmt.Errorf("failed to read 
response body: %v", err) } - return "", fmt.Errorf("unexpected end of stream") + return nil, fmt.Errorf("unexpected end of stream") } diff --git a/pkg/api/provider/anthropic/tools.go b/pkg/api/provider/anthropic/tools.go index d2faa4d..2314aa3 100644 --- a/pkg/api/provider/anthropic/tools.go +++ b/pkg/api/provider/anthropic/tools.go @@ -6,7 +6,7 @@ import ( "strings" "text/template" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/api" ) const FUNCTION_STOP_SEQUENCE = "" @@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} { return ret } -func convertToolsToXMLTools(tools []model.Tool) XMLTools { +func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools { converted := make([]XMLToolDescription, len(tools)) for i, tool := range tools { converted[i].ToolName = tool.Name @@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []model.Tool) XMLTools { } } -func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall { - toolCalls := make([]model.ToolCall, len(functionCalls.Invoke)) +func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall { + toolCalls := make([]api.ToolCall, len(functionCalls.Invoke)) for i, invoke := range functionCalls.Invoke { toolCalls[i].Name = invoke.ToolName toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String) @@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model. return toolCalls } -func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls { +func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls { converted := make([]XMLFunctionInvoke, len(toolCalls)) for i, toolCall := range toolCalls { var params XMLFunctionInvokeParameters @@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionC } } -func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults { +func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults { converted := make([]XMLFunctionResult, len(toolResults)) for i, result := range toolResults { converted[i].ToolName = result.ToolName @@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFu } } -func buildToolsSystemPrompt(tools []model.Tool) string { +func buildToolsSystemPrompt(tools []api.ToolSpec) string { xmlTools := convertToolsToXMLTools(tools) xmlToolsString, err := xmlTools.XMLString() if err != nil { - panic("Could not serialize []model.Tool to XMLTools") + panic("Could not serialize []api.Tool to XMLTools") } return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER } diff --git a/pkg/api/provider/google/google.go b/pkg/api/provider/google/google.go index 06d6cba..7c38ffb 100644 --- a/pkg/api/provider/google/google.go +++ b/pkg/api/provider/google/google.go @@ -11,11 +11,9 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) -func convertTools(tools []model.Tool) []Tool { +func convertTools(tools []api.ToolSpec) []Tool { geminiTools := make([]Tool, len(tools)) for i, tool := range tools { params := make(map[string]ToolParameter) @@ -50,7 +48,7 @@ func convertTools(tools []model.Tool) []Tool { return geminiTools } -func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { +func convertToolCallToGemini(toolCalls []api.ToolCall) 
[]ContentPart { converted := make([]ContentPart, len(toolCalls)) for i, call := range toolCalls { args := make(map[string]string) @@ -65,8 +63,8 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { return converted } -func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall { - converted := make([]model.ToolCall, len(functionCalls)) +func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall { + converted := make([]api.ToolCall, len(functionCalls)) for i, call := range functionCalls { params := make(map[string]interface{}) for k, v := range call.Args { @@ -78,7 +76,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall { return converted } -func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) { +func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) { results := make([]FunctionResponse, len(toolResults)) for i, result := range toolResults { var obj interface{} @@ -95,14 +93,14 @@ func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionRespo } func createGenerateContentRequest( - params model.RequestParameters, - messages []model.Message, + params api.RequestParameters, + messages []api.Message, ) (*GenerateContentRequest, error) { requestContents := make([]Content, 0, len(messages)) startIdx := 0 var system string - if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { + if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem { system = messages[0].Content startIdx = 1 } @@ -135,9 +133,9 @@ func createGenerateContentRequest( default: var role string switch m.Role { - case model.MessageRoleAssistant: + case api.MessageRoleAssistant: role = "model" - case model.MessageRoleUser: + case api.MessageRoleUser: role = "user" } @@ -183,55 +181,14 @@ func createGenerateContentRequest( return request, nil } -func handleToolCalls( - params model.RequestParameters, - content string, - toolCalls []model.ToolCall, - callback api.ReplyCallback, - messages []model.Message, -) ([]model.Message, error) { - lastMessage := messages[len(messages)-1] - continuation := false - if lastMessage.Role.IsAssistant() { - continuation = true - } - - toolCall := model.Message{ - Role: model.MessageRoleToolCall, - Content: content, - ToolCalls: toolCalls, - } - - toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) - if err != nil { - return nil, err - } - - toolResult := model.Message{ - Role: model.MessageRoleToolResult, - ToolResults: toolResults, - } - - if callback != nil { - callback(toolCall) - callback(toolResult) - } - - if continuation { - messages[len(messages)-1] = toolCall - } else { - messages = append(messages, toolCall) - } - messages = append(messages, toolResult) - - return messages, nil -} - -func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { +func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") client := &http.Client{} - resp, err := client.Do(req.WithContext(ctx)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } if resp.StatusCode != 200 { bytes, _ := io.ReadAll(resp.Body) @@ -243,42 +200,41 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Resp func (c *Client) CreateChatCompletion( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, -) (string, error) { + params 
api.RequestParameters, + messages []api.Message, +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req, err := createGenerateContentRequest(params, messages) if err != nil { - return "", err + return nil, err } jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } url := fmt.Sprintf( "%s/v1beta/models/%s:generateContent?key=%s", c.BaseURL, params.Model, c.APIKey, ) - httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() var completionResp GenerateContentResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { - return "", err + return nil, err } choice := completionResp.Candidates[0] @@ -301,58 +257,50 @@ func (c *Client) CreateChatCompletion( } if len(toolCalls) > 0 { - messages, err := handleToolCalls( - params, content, convertToolCallToAPI(toolCalls), callback, messages, - ) - if err != nil { - return content, err - } - - return c.CreateChatCompletion(ctx, params, messages, callback) + return &api.Message{ + Role: api.MessageRoleToolCall, + Content: content, + ToolCalls: convertToolCallToAPI(toolCalls), + }, nil } - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content, - }) - } - - return content, nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content, + }, nil } func (c *Client) CreateChatCompletionStream( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, + params api.RequestParameters, + messages []api.Message, output chan<- api.Chunk, -) (string, error) { +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req, err := createGenerateContentRequest(params, messages) if err != nil { - return "", err + return nil, err } jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } url := fmt.Sprintf( "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", c.BaseURL, params.Model, c.APIKey, ) - httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() @@ -374,7 +322,7 @@ func (c *Client) CreateChatCompletionStream( if err == io.EOF { break } - return "", err + return nil, err } line = bytes.TrimSpace(line) @@ -387,7 +335,7 @@ func (c *Client) CreateChatCompletionStream( var resp GenerateContentResponse err = json.Unmarshal(line, &resp) if err != nil { - return "", err + return nil, err } tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount @@ -409,21 +357,15 @@ func (c *Client) CreateChatCompletionStream( // If there are function calls, handle them and recurse if len(toolCalls) > 0 { - messages, err := handleToolCalls( - params, content.String(), 
convertToolCallToAPI(toolCalls), callback, messages, - ) - if err != nil { - return content.String(), err - } - return c.CreateChatCompletionStream(ctx, params, messages, callback, output) + return &api.Message{ + Role: api.MessageRoleToolCall, + Content: content.String(), + ToolCalls: convertToolCallToAPI(toolCalls), + }, nil } - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content.String(), - }) - } - - return content.String(), nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content.String(), + }, nil } diff --git a/pkg/api/provider/ollama/ollama.go b/pkg/api/provider/ollama/ollama.go index a7b9ca5..960c282 100644 --- a/pkg/api/provider/ollama/ollama.go +++ b/pkg/api/provider/ollama/ollama.go @@ -11,7 +11,6 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" ) type OllamaClient struct { @@ -43,8 +42,8 @@ type OllamaResponse struct { } func createOllamaRequest( - params model.RequestParameters, - messages []model.Message, + params api.RequestParameters, + messages []api.Message, ) OllamaRequest { requestMessages := make([]OllamaMessage, 0, len(messages)) @@ -64,11 +63,11 @@ func createOllamaRequest( return request } -func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { +func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") client := &http.Client{} - resp, err := client.Do(req.WithContext(ctx)) + resp, err := client.Do(req) if err != nil { return nil, err } @@ -83,12 +82,11 @@ func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*htt func (c *OllamaClient) CreateChatCompletion( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, -) (string, error) { + params api.RequestParameters, + messages []api.Message, +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req := createOllamaRequest(params, messages) @@ -96,46 +94,40 @@ func (c *OllamaClient) CreateChatCompletion( jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } - httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() var completionResp OllamaResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { - return "", err + return nil, err } - content := completionResp.Message.Content - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content, - }) - } - - return content, nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: completionResp.Message.Content, + }, nil } func (c *OllamaClient) CreateChatCompletionStream( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, + params api.RequestParameters, + messages []api.Message, output chan<- api.Chunk, -) (string, error) { +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create 
completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req := createOllamaRequest(params, messages) @@ -143,17 +135,17 @@ func (c *OllamaClient) CreateChatCompletionStream( jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } - httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() @@ -166,7 +158,7 @@ func (c *OllamaClient) CreateChatCompletionStream( if err == io.EOF { break } - return "", err + return nil, err } line = bytes.TrimSpace(line) @@ -177,7 +169,7 @@ func (c *OllamaClient) CreateChatCompletionStream( var streamResp OllamaResponse err = json.Unmarshal(line, &streamResp) if err != nil { - return "", err + return nil, err } if len(streamResp.Message.Content) > 0 { @@ -189,12 +181,8 @@ func (c *OllamaClient) CreateChatCompletionStream( } } - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content.String(), - }) - } - - return content.String(), nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content.String(), + }, nil } diff --git a/pkg/api/provider/openai/openai.go b/pkg/api/provider/openai/openai.go index 9bd199c..e2b010b 100644 --- a/pkg/api/provider/openai/openai.go +++ b/pkg/api/provider/openai/openai.go @@ -11,11 +11,9 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" ) -func convertTools(tools []model.Tool) []Tool { +func convertTools(tools []api.ToolSpec) []Tool { openaiTools := make([]Tool, len(tools)) for i, tool := range tools { openaiTools[i].Type = "function" @@ -47,7 +45,7 @@ func convertTools(tools []model.Tool) []Tool { return openaiTools } -func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall { +func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall { converted := make([]ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].Type = "function" @@ -60,8 +58,8 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall { return converted } -func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall { - converted := make([]model.ToolCall, len(toolCalls)) +func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall { + converted := make([]api.ToolCall, len(toolCalls)) for i, call := range toolCalls { converted[i].ID = call.ID converted[i].Name = call.Function.Name @@ -71,8 +69,8 @@ func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall { } func createChatCompletionRequest( - params model.RequestParameters, - messages []model.Message, + params api.RequestParameters, + messages []api.Message, ) ChatCompletionRequest { requestMessages := make([]ChatCompletionMessage, 0, len(messages)) @@ -117,56 +115,15 @@ func createChatCompletionRequest( return request } -func handleToolCalls( - params model.RequestParameters, - content string, - toolCalls []ToolCall, - callback api.ReplyCallback, - messages []model.Message, -) ([]model.Message, error) { - lastMessage := messages[len(messages)-1] - continuation := false - if lastMessage.Role.IsAssistant() { - continuation = true - } - - toolCall := model.Message{ - Role: 
model.MessageRoleToolCall, - Content: content, - ToolCalls: convertToolCallToAPI(toolCalls), - } - - toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag) - if err != nil { - return nil, err - } - - toolResult := model.Message{ - Role: model.MessageRoleToolResult, - ToolResults: toolResults, - } - - if callback != nil { - callback(toolCall) - callback(toolResult) - } - - if continuation { - messages[len(messages)-1] = toolCall - } else { - messages = append(messages, toolCall) - } - messages = append(messages, toolResult) - - return messages, nil -} - -func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { +func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.APIKey) client := &http.Client{} - resp, err := client.Do(req.WithContext(ctx)) + resp, err := client.Do(req) + if err != nil { + return nil, err + } if resp.StatusCode != 200 { bytes, _ := io.ReadAll(resp.Body) @@ -178,35 +135,34 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*htt func (c *OpenAIClient) CreateChatCompletion( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, -) (string, error) { + params api.RequestParameters, + messages []api.Message, +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req := createChatCompletionRequest(params, messages) jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } - httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() var completionResp ChatCompletionResponse err = json.NewDecoder(resp.Body).Decode(&completionResp) if err != nil { - return "", err + return nil, err } choice := completionResp.Choices[0] @@ -221,34 +177,27 @@ func (c *OpenAIClient) CreateChatCompletion( toolCalls := choice.Message.ToolCalls if len(toolCalls) > 0 { - messages, err := handleToolCalls(params, content, toolCalls, callback, messages) - if err != nil { - return content, err - } - - return c.CreateChatCompletion(ctx, params, messages, callback) + return &api.Message{ + Role: api.MessageRoleToolCall, + Content: content, + ToolCalls: convertToolCallToAPI(toolCalls), + }, nil } - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content, - }) - } - - // Return the user-facing message. 
- return content, nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content, + }, nil } func (c *OpenAIClient) CreateChatCompletionStream( ctx context.Context, - params model.RequestParameters, - messages []model.Message, - callback api.ReplyCallback, + params api.RequestParameters, + messages []api.Message, output chan<- api.Chunk, -) (string, error) { +) (*api.Message, error) { if len(messages) == 0 { - return "", fmt.Errorf("Can't create completion from no messages") + return nil, fmt.Errorf("Can't create completion from no messages") } req := createChatCompletionRequest(params, messages) @@ -256,17 +205,17 @@ func (c *OpenAIClient) CreateChatCompletionStream( jsonData, err := json.Marshal(req) if err != nil { - return "", err + return nil, err } - httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) + httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) if err != nil { - return "", err + return nil, err } - resp, err := c.sendRequest(ctx, httpReq) + resp, err := c.sendRequest(httpReq) if err != nil { - return "", err + return nil, err } defer resp.Body.Close() @@ -285,7 +234,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( if err == io.EOF { break } - return "", err + return nil, err } line = bytes.TrimSpace(line) @@ -301,7 +250,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( var streamResp ChatCompletionStreamResponse err = json.Unmarshal(line, &streamResp) if err != nil { - return "", err + return nil, err } delta := streamResp.Choices[0].Delta @@ -309,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( // Construct streamed tool_call arguments for _, tc := range delta.ToolCalls { if tc.Index == nil { - return "", fmt.Errorf("Unexpected nil index for streamed tool call.") + return nil, fmt.Errorf("Unexpected nil index for streamed tool call.") } if len(toolCalls) <= *tc.Index { toolCalls = append(toolCalls, tc) @@ -328,21 +277,15 @@ func (c *OpenAIClient) CreateChatCompletionStream( } if len(toolCalls) > 0 { - messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) - if err != nil { - return content.String(), err - } - - // Recurse into CreateChatCompletionStream with the tool call replies - return c.CreateChatCompletionStream(ctx, params, messages, callback, output) - } else { - if callback != nil { - callback(model.Message{ - Role: model.MessageRoleAssistant, - Content: content.String(), - }) - } + return &api.Message{ + Role: api.MessageRoleToolCall, + Content: content.String(), + ToolCalls: convertToolCallToAPI(toolCalls), + }, nil } - return content.String(), nil + return &api.Message{ + Role: api.MessageRoleAssistant, + Content: content.String(), + }, nil } diff --git a/pkg/lmcli/model/tool.go b/pkg/api/tools.go similarity index 95% rename from pkg/lmcli/model/tool.go rename to pkg/api/tools.go index 00bed0e..f4e3c72 100644 --- a/pkg/lmcli/model/tool.go +++ b/pkg/api/tools.go @@ -1,4 +1,4 @@ -package model +package api import ( "database/sql/driver" @@ -6,11 +6,11 @@ import ( "fmt" ) -type Tool struct { +type ToolSpec struct { Name string Description string Parameters []ToolParameter - Impl func(*Tool, map[string]interface{}) (string, error) + Impl func(*ToolSpec, map[string]interface{}) (string, error) } type ToolParameter struct { @@ -27,6 +27,12 @@ type ToolCall struct { Parameters map[string]interface{} `json:"parameters" yaml:"parameters"` } +type ToolResult struct { + ToolCallID string 
`json:"toolCallID" yaml:"-"` + ToolName string `json:"toolName,omitempty" yaml:"tool"` + Result string `json:"result,omitempty" yaml:"result"` +} + type ToolCalls []ToolCall func (tc *ToolCalls) Scan(value any) (err error) { @@ -50,12 +56,6 @@ func (tc ToolCalls) Value() (driver.Value, error) { return string(jsonBytes), nil } -type ToolResult struct { - ToolCallID string `json:"toolCallID" yaml:"-"` - ToolName string `json:"toolName,omitempty" yaml:"tool"` - Result string `json:"result,omitempty" yaml:"result"` -} - type ToolResults []ToolResult func (tr *ToolResults) Scan(value any) (err error) { diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go index 1ee4317..73503d0 100644 --- a/pkg/cmd/continue.go +++ b/pkg/cmd/continue.go @@ -4,9 +4,9 @@ import ( "fmt" "strings" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -36,7 +36,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { } lastMessage := &messages[len(messages)-1] - if lastMessage.Role != model.MessageRoleAssistant { + if lastMessage.Role != api.MessageRoleAssistant { return fmt.Errorf("the last message in the conversation is not an assistant message") } @@ -50,7 +50,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command { } // Append the new response to the original message - lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ") + lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ") // Update the original message err = ctx.Store.UpdateMessage(lastMessage) diff --git a/pkg/cmd/edit.go b/pkg/cmd/edit.go index c710a95..fe6dd28 100644 --- a/pkg/cmd/edit.go +++ b/pkg/cmd/edit.go @@ -3,9 +3,9 @@ package cmd import ( "fmt" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -53,10 +53,10 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command { role, _ := cmd.Flags().GetString("role") if role != "" { - if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { + if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) { return fmt.Errorf("Invalid role specified. 
Please use 'user' or 'assistant'.") } - toEdit.Role = model.MessageRole(role) + toEdit.Role = api.MessageRole(role) } // Update the message in-place diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go index e5a74da..aca7a9e 100644 --- a/pkg/cmd/new.go +++ b/pkg/cmd/new.go @@ -3,9 +3,9 @@ package cmd import ( "fmt" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -20,19 +20,19 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - var messages []model.Message + var messages []api.Message // TODO: probably just make this part of the conversation system := ctx.GetSystemPrompt() if system != "" { - messages = append(messages, model.Message{ - Role: model.MessageRoleSystem, + messages = append(messages, api.Message{ + Role: api.MessageRoleSystem, Content: system, }) } - messages = append(messages, model.Message{ - Role: model.MessageRoleUser, + messages = append(messages, api.Message{ + Role: api.MessageRoleUser, Content: input, }) diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go index 13e5998..8664c91 100644 --- a/pkg/cmd/prompt.go +++ b/pkg/cmd/prompt.go @@ -3,9 +3,9 @@ package cmd import ( "fmt" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -20,19 +20,19 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - var messages []model.Message + var messages []api.Message // TODO: stop supplying system prompt as a message system := ctx.GetSystemPrompt() if system != "" { - messages = append(messages, model.Message{ - Role: model.MessageRoleSystem, + messages = append(messages, api.Message{ + Role: api.MessageRoleSystem, Content: system, }) } - messages = append(messages, model.Message{ - Role: model.MessageRoleUser, + messages = append(messages, api.Message{ + Role: api.MessageRoleUser, Content: input, }) diff --git a/pkg/cmd/remove.go b/pkg/cmd/remove.go index 12cec51..8079ffb 100644 --- a/pkg/cmd/remove.go +++ b/pkg/cmd/remove.go @@ -4,9 +4,9 @@ import ( "fmt" "strings" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -23,7 +23,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command { return nil }, RunE: func(cmd *cobra.Command, args []string) error { - var toRemove []*model.Conversation + var toRemove []*api.Conversation for _, shortName := range args { conversation := cmdutil.LookupConversation(ctx, shortName) toRemove = append(toRemove, conversation) diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go index 8483285..6338566 100644 --- a/pkg/cmd/reply.go +++ b/pkg/cmd/reply.go @@ -3,9 +3,9 @@ package cmd import ( "fmt" + "git.mlow.ca/mlow/lmcli/pkg/api" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "github.com/spf13/cobra" ) @@ -30,8 +30,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No reply was provided.") } - cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ - Role: model.MessageRoleUser, + cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{ + Role: api.MessageRoleUser, Content: reply, }) return nil diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go 
index 82375cc..d88dd87 100644
--- a/pkg/cmd/retry.go
+++ b/pkg/cmd/retry.go
@@ -3,9 +3,9 @@ package cmd
 import (
 	"fmt"

+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"github.com/spf13/cobra"
 )
@@ -43,11 +43,11 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 			retryFromIdx := len(messages) - 1 - offset

 			// decrease retryFromIdx until we hit a user message
-			for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser {
+			for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
 				retryFromIdx--
 			}

-			if messages[retryFromIdx].Role != model.MessageRoleUser {
+			if messages[retryFromIdx].Role != api.MessageRoleUser {
 				return fmt.Errorf("No user messages to retry")
 			}
diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index 8c30685..c4407fa 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -10,36 +10,36 @@ import (

 	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"github.com/charmbracelet/lipgloss"
 )

 // Prompt prompts the configured model and streams the response
 // to stdout. Returns all model reply messages.
-func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
-	content := make(chan api.Chunk) // receives the reponse from LLM
-	defer close(content)
-
-	// render all content received over the channel
-	go ShowDelayedContent(content)
-
+func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
 	m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return "", err
+		return nil, err
 	}

-	requestParams := model.RequestParameters{
+	requestParams := api.RequestParameters{
 		Model:       m,
 		MaxTokens:   *ctx.Config.Defaults.MaxTokens,
 		Temperature: *ctx.Config.Defaults.Temperature,
 		ToolBag:     ctx.EnabledTools,
 	}

-	response, err := provider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, callback, content,
+	content := make(chan api.Chunk)
+	defer close(content)
+
+	// render the content received over the channel
+	go ShowDelayedContent(content)
+
+	reply, err := provider.CreateChatCompletionStream(
+		context.Background(), requestParams, messages, content,
 	)
-	if response != "" {
+
+	if reply.Content != "" {
 		// there was some content, so break to a new line after it
 		fmt.Println()
@@ -48,12 +48,12 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
 			err = nil
 		}
 	}
-	return response, err
+	return reply, err
 }

 // LookupConversation either returns the conversation found by the
 // short name or exits the program
-func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
+func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
 	c, err := ctx.Store.ConversationByShortName(shortName)
 	if err != nil {
 		lmcli.Fatal("Could not lookup conversation: %v\n", err)
 	}
@@ -64,7 +64,7 @@
 	return c
 }

-func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
+func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
 	c, err := ctx.Store.ConversationByShortName(shortName)
 	if err != nil {
 		return nil, fmt.Errorf("Could not lookup conversation: %v", err)
 	}
@@ -75,7 +75,7 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
 	return c, nil
 }

-func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
+func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
 	messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
 	if err != nil {
 		lmcli.Fatal("Could not load messages: %v\n", err)
 	}
@@ -85,7 +85,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist

 // HandleReply handles sending messages to an existing
 // conversation, optionally persisting both the sent replies and responses.
-func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
+func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
 	if to == nil {
 		lmcli.Fatal("Can't prompt from an empty message.")
 	}
@@ -97,7 +97,7 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
 	RenderConversation(ctx, append(existing, messages...), true)

-	var savedReplies []model.Message
+	var savedReplies []api.Message
 	if persist && len(messages) > 0 {
 		savedReplies, err = ctx.Store.Reply(to, messages...)
 		if err != nil {
@@ -106,15 +106,15 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
 	}

 	// render a message header with no contents
-	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
+	RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))

-	var lastSavedMessage *model.Message
+	var lastSavedMessage *api.Message
 	lastSavedMessage = to
 	if len(savedReplies) > 0 {
 		lastSavedMessage = &savedReplies[len(savedReplies)-1]
 	}

-	replyCallback := func(reply model.Message) {
+	replyCallback := func(reply api.Message) {
 		if !persist {
 			return
 		}
@@ -131,16 +131,16 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
 	}
 }

-func FormatForExternalPrompt(messages []model.Message, system bool) string {
+func FormatForExternalPrompt(messages []api.Message, system bool) string {
 	sb := strings.Builder{}
 	for _, message := range messages {
 		if message.Content == "" {
 			continue
 		}
 		switch message.Role {
-		case model.MessageRoleAssistant, model.MessageRoleToolCall:
+		case api.MessageRoleAssistant, api.MessageRoleToolCall:
 			sb.WriteString("Assistant:\n\n")
-		case model.MessageRoleUser:
+		case api.MessageRoleUser:
 			sb.WriteString("User:\n\n")
 		default:
 			continue
@@ -150,7 +150,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
 	return sb.String()
 }

-func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
+func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
 	const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
 
 Example conversation:
@@ -177,28 +177,32 @@ Example response:
 		return "", err
 	}

-	generateRequest := []model.Message{
+	generateRequest := []api.Message{
 		{
-			Role:    model.MessageRoleSystem,
+			Role:    api.MessageRoleSystem,
 			Content: systemPrompt,
 		},
 		{
-			Role:    model.MessageRoleUser,
+			Role:    api.MessageRoleUser,
 			Content: string(conversation),
 		},
 	}

-	m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
+	m, provider, err := ctx.GetModelProvider(
+		*ctx.Config.Conversations.TitleGenerationModel,
+	)
 	if err != nil {
 		return "", err
 	}

-	requestParams := model.RequestParameters{
+	requestParams := api.RequestParameters{
 		Model:     m,
 		MaxTokens: 25,
 	}

-	response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
+	response, err := provider.CreateChatCompletion(
+		context.Background(), requestParams, generateRequest,
+	)
 	if err != nil {
 		return "", err
 	}
@@ -207,7 +211,7 @@ Example response:
 	var jsonResponse struct {
 		Title string `json:"title"`
 	}
-	err = json.Unmarshal([]byte(response), &jsonResponse)
+	err = json.Unmarshal([]byte(response.Content), &jsonResponse)
 	if err != nil {
 		return "", err
 	}
@@ -272,7 +276,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {

 // RenderConversation renders the given messages to TTY, with optional space
 // for a subsequent message. spaceForResponse controls how many '\n' characters
 // are printed immediately after the final message (1 if false, 2 if true)
-func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
+func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
 	l := len(messages)
 	for i, message := range messages {
 		RenderMessage(ctx, &message)
@@ -283,7 +287,7 @@ func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForRe
 	}
 }

-func RenderMessage(ctx *lmcli.Context, m *model.Message) {
+func RenderMessage(ctx *lmcli.Context, m *api.Message) {
 	var messageAge string
 	if m.CreatedAt.IsZero() {
 		messageAge = "now"
@@ -295,11 +299,11 @@ func RenderMessage(ctx *lmcli.Context, m *model.Message) {
 	headerStyle := lipgloss.NewStyle().Bold(true)

 	switch m.Role {
-	case model.MessageRoleSystem:
+	case api.MessageRoleSystem:
 		headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
-	case model.MessageRoleUser:
+	case api.MessageRoleUser:
 		headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
-	case model.MessageRoleAssistant:
+	case api.MessageRoleAssistant:
 		headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
 	}
diff --git a/pkg/cmd/view.go b/pkg/cmd/view.go
index 772e869..0da608a 100644
--- a/pkg/cmd/view.go
+++ b/pkg/cmd/view.go
@@ -20,7 +20,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
 			}
 			return nil
 		},
-		RunE: func(cmd *cobra.Command, args []string) error { 
+		RunE: func(cmd *cobra.Command, args []string) error {
 			shortName := args[0]
 			conversation := cmdutil.LookupConversation(ctx, shortName)
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index a841d84..78b8886 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -6,13 +6,12 @@ import (
 	"path/filepath"
 	"strings"

+	"git.mlow.ca/mlow/lmcli/pkg/agent"
 	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util/tty" "gorm.io/driver/sqlite" @@ -24,7 +23,7 @@ type Context struct { Store ConversationStore Chroma *tty.ChromaHighlighter - EnabledTools []model.Tool + EnabledTools []api.ToolSpec SystemPromptFile string } @@ -50,9 +49,9 @@ func NewContext() (*Context, error) { chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) - var enabledTools []model.Tool + var enabledTools []api.ToolSpec for _, toolName := range config.Tools.EnabledTools { - tool, ok := tools.AvailableTools[toolName] + tool, ok := agent.AvailableTools[toolName] if ok { enabledTools = append(enabledTools, tool) } @@ -79,7 +78,7 @@ func (c *Context) GetModels() (models []string) { return } -func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) { +func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { parts := strings.Split(model, "@") var provider string diff --git a/pkg/lmcli/store.go b/pkg/lmcli/store.go index 3e10a64..1fd815e 100644 --- a/pkg/lmcli/store.go +++ b/pkg/lmcli/store.go @@ -8,32 +8,32 @@ import ( "strings" "time" - "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" + "git.mlow.ca/mlow/lmcli/pkg/api" sqids "github.com/sqids/sqids-go" "gorm.io/gorm" ) type ConversationStore interface { - ConversationByShortName(shortName string) (*model.Conversation, error) + ConversationByShortName(shortName string) (*api.Conversation, error) ConversationShortNameCompletions(search string) []string - RootMessages(conversationID uint) ([]model.Message, error) - LatestConversationMessages() ([]model.Message, error) + RootMessages(conversationID uint) ([]api.Message, error) + LatestConversationMessages() ([]api.Message, error) - StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) - UpdateConversation(conversation *model.Conversation) error - DeleteConversation(conversation *model.Conversation) error - CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) + StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) + UpdateConversation(conversation *api.Conversation) error + DeleteConversation(conversation *api.Conversation) error + CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) - MessageByID(messageID uint) (*model.Message, error) - MessageReplies(messageID uint) ([]model.Message, error) + MessageByID(messageID uint) (*api.Message, error) + MessageReplies(messageID uint) ([]api.Message, error) - UpdateMessage(message *model.Message) error - DeleteMessage(message *model.Message, prune bool) error - CloneBranch(toClone model.Message) (*model.Message, uint, error) - Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) + UpdateMessage(message *api.Message) error + DeleteMessage(message *api.Message, prune bool) error + CloneBranch(toClone api.Message) (*api.Message, uint, error) + Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) - PathToRoot(message *model.Message) ([]model.Message, error) - PathToLeaf(message *model.Message) ([]model.Message, error) + PathToRoot(message *api.Message) ([]api.Message, error) + PathToLeaf(message *api.Message) ([]api.Message, error) } type SQLStore struct { @@ -43,8 +43,8 @@ type SQLStore struct { func NewSQLStore(db *gorm.DB) (*SQLStore, error) { models := []any{ - &model.Conversation{}, - &model.Message{}, + &api.Conversation{}, + &api.Message{}, } for _, x := range models 
@@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
 	return &SQLStore{db, _sqids}, nil
 }

-func (s *SQLStore) createConversation() (*model.Conversation, error) {
+func (s *SQLStore) createConversation() (*api.Conversation, error) {
 	// Create the new conversation
-	c := &model.Conversation{}
+	c := &api.Conversation{}
 	err := s.db.Save(c).Error
 	if err != nil {
 		return nil, err
 	}
@@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*model.Conversation, error) {
 	return c, nil
 }

-func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
+func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
 	if c == nil || c.ID == 0 {
 		return fmt.Errorf("Conversation is nil or invalid (missing ID)")
 	}
 	return s.db.Updates(c).Error
 }

-func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
+func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
 	// Delete messages first
-	err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
+	err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
 	if err != nil {
 		return err
 	}
 	return s.db.Delete(c).Error
 }

-func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
+func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
 	panic("Not yet implemented")
 	//return s.db.Delete(&message).Error
 }

-func (s *SQLStore) UpdateMessage(m *model.Message) error {
+func (s *SQLStore) UpdateMessage(m *api.Message) error {
 	if m == nil || m.ID == 0 {
 		return fmt.Errorf("Message is nil or invalid (missing ID)")
 	}
@@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
 }

 func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
-	var conversations []model.Conversation
+	var conversations []api.Conversation
 	// ignore error for completions
 	s.db.Find(&conversations)
 	completions := make([]string, 0, len(conversations))
@@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
 	return completions
 }

-func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
+func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
 	if shortName == "" {
 		return nil, errors.New("shortName is empty")
 	}
-	var conversation model.Conversation
+	var conversation api.Conversation
 	err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
 	return &conversation, err
 }

-func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
-	var rootMessages []model.Message
+func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
+	var rootMessages []api.Message
 	err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
 	if err != nil {
 		return nil, err
 	}
@@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
 	return rootMessages, nil
 }

-func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
-	var message model.Message
+func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
+	var message api.Message
 	err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
 	return &message, err
 }

-func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
-	var replies []model.Message
+func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
+	var replies []api.Message
 	err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
 	return replies, err
 }

 // StartConversation starts a new conversation with the provided messages
-func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
+func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
 	if len(messages) == 0 {
 		return nil, nil, fmt.Errorf("Must provide at least 1 message")
 	}
@@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
 		if err != nil {
 			return nil, nil, err
 		}
-		messages = append([]model.Message{messages[0]}, newMessages...)
+		messages = append([]api.Message{messages[0]}, newMessages...)
 	}
 	return conversation, messages, nil
 }

 // CloneConversation clones the given conversation and all of its root messages
-func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
+func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
 	rootMessages, err := s.RootMessages(toClone.ID)
 	if err != nil {
 		return nil, 0, err
 	}
@@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
 }

 // Reply to a message with a series of messages (each following the next)
-func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
-	var savedMessages []model.Message
+func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
+	var savedMessages []api.Message

 	err := s.db.Transaction(func(tx *gorm.DB) error {
 		currentParent := to
@@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.

 // CloneBranch returns a deep clone of the given message and its replies, returning
 // a new message object. The new message will be attached to the same parent as
 // the messageToClone
-func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
+func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
 	newMessage := messageToClone
 	newMessage.ID = 0
 	newMessage.Replies = nil
@@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
 	return &newMessage, replyCount, nil
 }

-func fetchMessages(db *gorm.DB) ([]model.Message, error) {
-	var messages []model.Message
+func fetchMessages(db *gorm.DB) ([]api.Message, error) {
+	var messages []api.Message
 	if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
 		return nil, fmt.Errorf("Could not fetch messages: %v", err)
 	}

-	messageMap := make(map[uint]model.Message)
+	messageMap := make(map[uint]api.Message)
 	for i, message := range messages {
 		messageMap[messages[i].ID] = message
 	}

 	// Create a map to store replies by their parent ID
-	repliesMap := make(map[uint][]model.Message)
+	repliesMap := make(map[uint][]api.Message)
 	for i, message := range messages {
 		if messages[i].ParentID != nil {
 			repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
@@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
 	// Assign replies, parent, and selected reply to each message
 	for i := range messages {
 		if replies, exists := repliesMap[messages[i].ID]; exists {
-			messages[i].Replies = make([]model.Message, len(replies))
+			messages[i].Replies = make([]api.Message, len(replies))
 			for j, m := range replies {
 				messages[i].Replies[j] = m
 			}
@@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
 	return messages, nil
 }

-func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
-	var messages []model.Message
+func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
+	var messages []api.Message
 	messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
 	if err != nil {
 		return nil, err
 	}

 	// Create a map to store messages by their ID
-	messageMap := make(map[uint]*model.Message)
+	messageMap := make(map[uint]*api.Message)
 	for i := range messages {
 		messageMap[messages[i].ID] = &messages[i]
 	}

 	// Build the path
-	var path []model.Message
+	var path []api.Message
 	nextID := &message.ID

 	for {
@@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message

 // PathToRoot traverses the provided message's Parent until reaching the tree
 // root and returns a slice of all messages traversed in chronological order
 // (starting with the root and ending with the message provided)
-func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
+func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
 	if message == nil || message.ID <= 0 {
 		return nil, fmt.Errorf("Message is nil or has invalid ID")
 	}

-	path, err := s.buildPath(message, func(m *model.Message) *uint {
+	path, err := s.buildPath(message, func(m *api.Message) *uint {
 		return m.ParentID
 	})
 	if err != nil {
@@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {

 // PathToLeaf traverses the provided message's SelectedReply until reaching a
 // tree leaf and returns a slice of all messages traversed in chronological
 // order (starting with the message provided and ending with the leaf)
-func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
+func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
 	if message == nil || message.ID <= 0 {
 		return nil, fmt.Errorf("Message is nil or has invalid ID")
 	}

-	return s.buildPath(message, func(m *model.Message) *uint {
+	return s.buildPath(message, func(m *api.Message) *uint {
 		return m.SelectedReplyID
 	})
 }

-func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
-	var latestMessages []model.Message
+func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
+	var latestMessages []api.Message

-	subQuery := s.db.Model(&model.Message{}).
+	subQuery := s.db.Model(&api.Message{}).
 		Select("MAX(created_at) as max_created_at, conversation_id").
 		Group("conversation_id")

-	err := s.db.Model(&model.Message{}).
+	err := s.db.Model(&api.Message{}).
 		Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
 		Group("messages.conversation_id").
 		Order("created_at DESC").
diff --git a/pkg/lmcli/tools/tools.go b/pkg/lmcli/tools/tools.go
deleted file mode 100644
index bc02f55..0000000
--- a/pkg/lmcli/tools/tools.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package tools
-
-import (
-	"fmt"
-
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-)
-
-var AvailableTools map[string]model.Tool = map[string]model.Tool{
-	"dir_tree":           DirTreeTool,
-	"read_dir":           ReadDirTool,
-	"read_file":          ReadFileTool,
-	"write_file":         WriteFileTool,
-	"file_insert_lines":  FileInsertLinesTool,
-	"file_replace_lines": FileReplaceLinesTool,
-}
-
-func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
-	var toolResults []model.ToolResult
-	for _, toolCall := range toolCalls {
-		var tool *model.Tool
-		for _, available := range toolBag {
-			if available.Name == toolCall.Name {
-				tool = &available
-				break
-			}
-		}
-		if tool == nil {
-			return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
-		}
-
-		// Execute the tool
-		result, err := tool.Impl(tool, toolCall.Parameters)
-		if err != nil {
-			// This can happen if the model missed or supplied invalid tool args
-			return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
-		}
-
-		toolResult := model.ToolResult{
-			ToolCallID: toolCall.ID,
-			ToolName:   toolCall.Name,
-			Result:     result,
-		}
-
-		toolResults = append(toolResults, toolResult)
-	}
-	return toolResults, nil
-}
diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go
index f18a577..cdd5223 100644
--- a/pkg/tui/views/chat/chat.go
+++ b/pkg/tui/views/chat/chat.go
@@ -4,7 +4,6 @@ import (
 	"time"

 	"git.mlow.ca/mlow/lmcli/pkg/api"
-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	"github.com/charmbracelet/bubbles/cursor"
 	"github.com/charmbracelet/bubbles/spinner"
@@ -16,37 +15,39 @@ import (

 // custom tea.Msg types
 type (
-	// sent on each chunk received from LLM
-	msgResponseChunk api.Chunk
-	// sent when response is finished being received
-	msgResponseEnd string
-	// a special case of common.MsgError that stops the response waiting animation
-	msgResponseError error
-	// sent on each completed reply
-	msgResponse models.Message
 	// sent when a conversation is (re)loaded
 	msgConversationLoaded struct {
-		conversation *models.Conversation
-		rootMessages []models.Message
+		conversation *api.Conversation
+		rootMessages []api.Message
 	}
 	// sent when a new conversation title generated
 	msgConversationTitleGenerated string
-	// sent when a conversation's messages are laoded
-	msgMessagesLoaded []models.Message
 	// sent when the conversation has been persisted, triggers a reload of contents
 	msgConversationPersisted struct {
 		isNew        bool
-		conversation *models.Conversation
-		messages     []models.Message
+		conversation *api.Conversation
+		messages     []api.Message
 	}
+	// sent when a conversation's messages are loaded
+	msgMessagesLoaded []api.Message
+	// a special case of common.MsgError that stops the response waiting animation
+	msgChatResponseError error
+	// sent on each chunk received from LLM
+	msgChatResponseChunk api.Chunk
+	// sent on each completed reply
+	msgChatResponse *api.Message
+	// sent when the response is canceled
+	msgChatResponseCanceled struct{}
+	// sent when results from a tool call are returned
+	msgToolResults []api.ToolResult
 	// sent when the given message is made the new selected reply of its parent
-	msgSelectedReplyCycled *models.Message
+	msgSelectedReplyCycled *api.Message
 	// sent when the given message is made the new selected root of the current conversation
-	msgSelectedRootCycled *models.Message
+	msgSelectedRootCycled *api.Message
 	// sent when a message's contents are updated and saved
-	msgMessageUpdated *models.Message
+	msgMessageUpdated *api.Message
 	// sent when a message is cloned, with the cloned message
-	msgMessageCloned *models.Message
+	msgMessageCloned *api.Message
 )

 type focusState int
@@ -77,14 +78,14 @@ type Model struct {
 	// app state
 	state state // current overall status of the view
-	conversation *models.Conversation
-	rootMessages []models.Message
-	messages     []models.Message
+	conversation *api.Conversation
+	rootMessages []api.Message
+	messages     []api.Message
 	selectedMessage int
 	editorTarget    editorTarget
 	stopSignal      chan struct{}
-	replyChan       chan models.Message
-	replyChunkChan  chan api.Chunk
+	replyChan       chan api.Message
+	chatReplyChunks chan api.Chunk
 	persistence bool // whether we will save new messages in the conversation

 	// ui state
@@ -111,12 +112,12 @@ func Chat(shared shared.Shared) Model {
 		Shared: shared,
 		state:  idle,

-		conversation: &models.Conversation{},
+		conversation: &api.Conversation{},
 		persistence:  true,

-		stopSignal:     make(chan struct{}),
-		replyChan:      make(chan models.Message),
-		replyChunkChan: make(chan api.Chunk),
+		stopSignal:      make(chan struct{}),
+		replyChan:       make(chan api.Message),
+		chatReplyChunks: make(chan api.Chunk),

 		wrap:            true,
 		selectedMessage: -1,
@@ -144,8 +145,8 @@ func Chat(shared shared.Shared) Model {

 	system := shared.Ctx.GetSystemPrompt()
 	if system != "" {
-		m.messages = []models.Message{{
-			Role:    models.MessageRoleSystem,
+		m.messages = []api.Message{{
+			Role:    api.MessageRoleSystem,
 			Content: system,
 		}}
 	}
@@ -166,6 +167,5 @@ func Chat(shared shared.Shared) Model {
 func (m Model) Init() tea.Cmd {
 	return tea.Batch(
 		m.waitForResponseChunk(),
-		m.waitForResponse(),
 	)
 }
diff --git a/pkg/tui/views/chat/conversation.go b/pkg/tui/views/chat/conversation.go
index 59f5220..850d5aa 100644
--- a/pkg/tui/views/chat/conversation.go
+++ b/pkg/tui/views/chat/conversation.go
@@ -2,16 +2,18 @@ package chat

 import (
 	"context"
+	"errors"
 	"fmt"
 	"time"

+	"git.mlow.ca/mlow/lmcli/pkg/agent"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	tea "github.com/charmbracelet/bubbletea"
 )

-func (m *Model) setMessage(i int, msg models.Message) {
+func (m *Model) setMessage(i int, msg api.Message) {
 	if i >= len(m.messages) {
 		panic("i out of range")
 	}
@@ -19,7 +21,7 @@ func (m *Model) setMessage(i int, msg models.Message) {
 	m.messageCache[i] = m.renderMessage(i)
 }

-func (m *Model) addMessage(msg models.Message) {
+func (m *Model) addMessage(msg api.Message) {
 	m.messages = append(m.messages, msg)
 	m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
 }
@@ -88,7 +90,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
 	}
 }

-func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
+func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
 	return func() tea.Msg {
 		err := m.Shared.Ctx.Store.UpdateConversation(conversation)
 		if err != nil {
@@ -101,7 +103,7 @@ func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.C

 // Clones the given message (and its descendants). If selected is true, updates
 // either its parent's SelectedReply or its conversation's SelectedRoot to
 // point to the new clone
-func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
+func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
 	return func() tea.Msg {
 		msg, _, err := m.Ctx.Store.CloneBranch(message)
 		if err != nil {
@@ -123,7 +125,7 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
 	}
 }

-func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
+func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
 	return func() tea.Msg {
 		err := m.Shared.Ctx.Store.UpdateMessage(message)
 		if err != nil {
@@ -133,7 +135,7 @@ func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
 	}
 }

-func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) {
+func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
 	currentIndex := -1
 	for i, reply := range choices {
 		if reply.ID == selected.ID {
@@ -158,7 +160,7 @@ func cycleSelectedMessage(selected *models.Message, choices []models.Message, di
 	return &choices[next], nil
 }

-func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd {
+func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
 	if len(m.rootMessages) < 2 {
 		return nil
 	}
@@ -178,7 +180,7 @@ func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDir
 	}
 }

-func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd {
+func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
 	if len(message.Replies) < 2 {
 		return nil
 	}
@@ -218,15 +220,12 @@ func (m *Model) persistConversation() tea.Cmd {
 	// else, we'll handle updating an existing conversation's messages
 	for i := range messages {
 		if messages[i].ID > 0 {
-			// message has an ID, update its contents
+			// message has an ID, update it
 			err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
 			if err != nil {
 				return shared.MsgError(err)
 			}
 		} else if i > 0 {
-			if messages[i].Content == "" {
-				continue
-			}
 			// message is new, so add it as a reply to the previous message
 			saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
 			if err != nil {
@@ -243,13 +242,23 @@ func (m *Model) persistConversation() tea.Cmd {
 	}
 }

+func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
+	return func() tea.Msg {
+		results, err := agent.ExecuteToolCalls(toolCalls, m.Ctx.EnabledTools)
+		if err != nil {
+			return shared.MsgError(err)
+		}
+		return msgToolResults(results)
+	}
+}
+
 func (m *Model) promptLLM() tea.Cmd {
 	m.state = pendingResponse
 	m.replyCursor.Blink = false

-	m.tokenCount = 0
 	m.startTime = time.Now()
 	m.elapsed = 0
+	m.tokenCount = 0

 	return func() tea.Msg {
 		model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
@@ -257,36 +266,34 @@
 			return shared.MsgError(err)
 		}

-		requestParams := models.RequestParameters{
+		requestParams := api.RequestParameters{
 			Model:       model,
 			MaxTokens:   *m.Shared.Ctx.Config.Defaults.MaxTokens,
 			Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
 			ToolBag:     m.Shared.Ctx.EnabledTools,
 		}

-		replyHandler := func(msg models.Message) {
-			m.replyChan <- msg
-		}
-
 		ctx, cancel := context.WithCancel(context.Background())
-		canceled := false
 		go func() {
 			select {
 			case <-m.stopSignal:
-				canceled = true
 				cancel()
 			}
 		}()

 		resp, err := provider.CreateChatCompletionStream(
-			ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
+			ctx, requestParams, m.messages, m.chatReplyChunks,
 		)
-		if err != nil && !canceled {
-			return msgResponseError(err)
+		if errors.Is(err, context.Canceled) {
+			return msgChatResponseCanceled(struct{}{})
 		}
-		return msgResponseEnd(resp)
+		if err != nil {
+			return msgChatResponseError(err)
+		}
+
+		return msgChatResponse(resp)
 	}
 }
diff --git a/pkg/tui/views/chat/input.go b/pkg/tui/views/chat/input.go
index ce73d88..e46190e 100644
--- a/pkg/tui/views/chat/input.go
+++ b/pkg/tui/views/chat/input.go
@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"strings"

-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
 	tea "github.com/charmbracelet/bubbletea"
 )
@@ -150,12 +150,12 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
 		return true, nil
 	}

-	if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
+	if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser {
 		return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
 	}

-	m.addMessage(models.Message{
-		Role:    models.MessageRoleUser,
+	m.addMessage(api.Message{
+		Role:    api.MessageRoleUser,
 		Content: input,
 	})
diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go
index e385fe3..9b58a3e 100644
--- a/pkg/tui/views/chat/update.go
+++ b/pkg/tui/views/chat/update.go
@@ -4,7 +4,7 @@ import (
 	"strings"
 	"time"

-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
 	"github.com/charmbracelet/bubbles/cursor"
@@ -21,15 +21,9 @@ func (m *Model) HandleResize(width, height int) {
 	}
 }

-func (m *Model) waitForResponse() tea.Cmd {
-	return func() tea.Msg {
-		return msgResponse(<-m.replyChan)
-	}
-}
-
 func (m *Model) waitForResponseChunk() tea.Cmd {
 	return func() tea.Msg {
-		return msgResponseChunk(<-m.replyChunkChan)
+		return msgChatResponseChunk(<-m.chatReplyChunks)
 	}
 }
@@ -48,7 +42,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
 			// clear existing messages if we're loading a new conversation
-			m.messages = []models.Message{}
+			m.messages = []api.Message{}
 			m.selectedMessage = 0
 		}
 	}
@@ -87,7 +81,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		}
 		m.rebuildMessageCache()
 		m.updateContent()
-	case msgResponseChunk:
+	case msgChatResponseChunk:
 		cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk

 		if msg.Content == "" {
@@ -100,8 +94,8 @@
 			m.setMessageContents(last, m.messages[last].Content+msg.Content)
 		} else {
 			// use chunk in new message
-			m.addMessage(models.Message{
-				Role:    models.MessageRoleAssistant,
+			m.addMessage(api.Message{
+				Role:    api.MessageRoleAssistant,
 				Content: msg.Content,
 			})
 		}
@@ -113,10 +107,10 @@
 		m.tokenCount += msg.TokenCount
 		m.elapsed = time.Now().Sub(m.startTime)
-	case msgResponse:
-		cmds = append(cmds, m.waitForResponse()) // wait for the next response
+	case msgChatResponse:
+		m.state = idle

-		reply := models.Message(msg)
+		reply := (*api.Message)(msg)
 		reply.Content = strings.TrimSpace(reply.Content)

 		last := len(m.messages) - 1
@@ -124,11 +118,18 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 			panic("Unexpected empty messages handling msgAssistantReply")
 		}

-		if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
-			// this was a continuation, so replace the previous message with the completed reply
-			m.setMessage(last, reply)
+		if m.messages[last].Role.IsAssistant() {
+			// TODO: handle continuations gracefully - some models support them well, others fail horribly.
+			m.setMessage(last, *reply)
 		} else {
-			m.addMessage(reply)
+			m.addMessage(*reply)
+		}
+
+		switch reply.Role {
+		case api.MessageRoleToolCall:
+			// TODO: user confirmation before execution
+			// m.state = waitingForConfirmation
+			cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
 		}

 		if m.persistence {
@@ -140,17 +141,32 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		}
 		m.updateContent()
-	case msgResponseEnd:
+	case msgChatResponseCanceled:
 		m.state = idle
-		last := len(m.messages) - 1
-		if last < 0 {
-			panic("Unexpected empty messages handling msgResponseEnd")
-		}
-		m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
 		m.updateContent()
-	case msgResponseError:
+	case msgChatResponseError:
 		m.state = idle
 		m.Shared.Err = error(msg)
+		m.updateContent()
+	case msgToolResults:
+		last := len(m.messages) - 1
+		if last < 0 {
+			panic("Unexpected empty messages handling msgToolResults")
+		}
+
+		if m.messages[last].Role != api.MessageRoleToolCall {
+			panic("Previous message not a tool call, unexpected")
+		}
+
+		m.addMessage(api.Message{
+			Role:        api.MessageRoleToolResult,
+			ToolResults: api.ToolResults(msg),
+		})
+
+		if m.persistence {
+			cmds = append(cmds, m.persistConversation())
+		}
+		m.updateContent()
 	case msgConversationTitleGenerated:
 		title := string(msg)
@@ -167,7 +183,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		m.conversation = msg.conversation
 		m.messages = msg.messages
 		if msg.isNew {
-			m.rootMessages = []models.Message{m.messages[0]}
+			m.rootMessages = []api.Message{m.messages[0]}
 		}
 		m.rebuildMessageCache()
 		m.updateContent()
diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go
index 028300f..40a1463 100644
--- a/pkg/tui/views/chat/view.go
+++ b/pkg/tui/views/chat/view.go
@@ -5,7 +5,7 @@ import (
 	"fmt"
 	"strings"

-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
 	tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
 	"github.com/charmbracelet/lipgloss"
@@ -63,22 +63,22 @@ func (m Model) View() string {
 	return lipgloss.JoinVertical(lipgloss.Left, sections...)
 }

-func (m *Model) renderMessageHeading(i int, message *models.Message) string {
+func (m *Model) renderMessageHeading(i int, message *api.Message) string {
 	icon := ""
 	friendly := message.Role.FriendlyRole()
 	style := lipgloss.NewStyle().Faint(true).Bold(true)

 	switch message.Role {
-	case models.MessageRoleSystem:
+	case api.MessageRoleSystem:
 		icon = "⚙️"
-	case models.MessageRoleUser:
+	case api.MessageRoleUser:
 		style = userStyle
-	case models.MessageRoleAssistant:
+	case api.MessageRoleAssistant:
 		style = assistantStyle
-	case models.MessageRoleToolCall:
+	case api.MessageRoleToolCall:
 		style = assistantStyle
-		friendly = models.MessageRoleAssistant.FriendlyRole()
-	case models.MessageRoleToolResult:
+		friendly = api.MessageRoleAssistant.FriendlyRole()
+	case api.MessageRoleToolResult:
 		icon = "🔧"
 	}
@@ -139,21 +139,21 @@ func (m *Model) renderMessage(i int) string {
 	}

 	// Show the assistant's cursor
-	if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
+	if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == api.MessageRoleAssistant {
 		sb.WriteString(m.replyCursor.View())
 	}

 	// Write tool call info
 	var toolString string
 	switch msg.Role {
-	case models.MessageRoleToolCall:
+	case api.MessageRoleToolCall:
 		bytes, err := yaml.Marshal(msg.ToolCalls)
 		if err != nil {
 			toolString = "Could not serialize ToolCalls"
 		} else {
 			toolString = "tool_calls:\n" + string(bytes)
 		}
-	case models.MessageRoleToolResult:
+	case api.MessageRoleToolResult:
 		if !m.showToolResults {
 			break
 		}
@@ -221,11 +221,11 @@ func (m *Model) conversationMessagesView() string {
 		m.messageOffsets[i] = lineCnt

 		switch message.Role {
-		case models.MessageRoleToolCall:
+		case api.MessageRoleToolCall:
 			if !m.showToolResults && message.Content == "" {
 				continue
 			}
-		case models.MessageRoleToolResult:
+		case api.MessageRoleToolResult:
 			if !m.showToolResults {
 				continue
 			}
@@ -251,9 +251,9 @@ func (m *Model) conversationMessagesView() string {
 	}

 	// Render a placeholder for the incoming assistant reply
-	if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
-		heading := m.renderMessageHeading(-1, &models.Message{
-			Role: models.MessageRoleAssistant,
+	if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant) {
+		heading := m.renderMessageHeading(-1, &api.Message{
+			Role: api.MessageRoleAssistant,
 		})
 		sb.WriteString(heading)
 		sb.WriteString("\n")
diff --git a/pkg/tui/views/conversations/conversations.go b/pkg/tui/views/conversations/conversations.go
index 46cd09b..b5b5c18 100644
--- a/pkg/tui/views/conversations/conversations.go
+++ b/pkg/tui/views/conversations/conversations.go
@@ -5,7 +5,7 @@ import (
 	"strings"
 	"time"

-	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
 	tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
@@ -16,15 +16,15 @@ import (
 )

 type loadedConversation struct {
-	conv      models.Conversation
-	lastReply models.Message
+	conv      api.Conversation
+	lastReply api.Message
 }

 type (
 	// sent when conversation list is loaded
 	msgConversationsLoaded ([]loadedConversation)
 	// sent when a conversation is selected
-	msgConversationSelected models.Conversation
+	msgConversationSelected api.Conversation
 )

 type Model struct {