Package restructure and API changes, several fixes
- More emphasis on `api` package. It now holds database model structs from `lmcli/models` (which is now gone) as well as the tool spec, call, and result types. `tools.Tool` is now `api.ToolSpec`. `api.ChatCompletionClient` was renamed to `api.ChatCompletionProvider`. - Change ChatCompletion interface and implementations to no longer do automatic tool call recursion - they simply return a ToolCall message which the caller can decide what to do with (e.g. prompt for user confirmation before executing) - `api.ChatCompletionProvider` functions have had their ReplyCallback parameter removed, as now they only return a single reply. - Added a top-level `agent` package, moved the current built-in tools implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now `agent.ExecuteToolCalls`. - Fixed request context handling in openai, google, ollama (use `NewRequestWithContext`), cleaned up request cancellation in TUI - Fix tool call tui persistence bug (we were skipping message with empty content) - Now handle tool calling from TUI layer TODO: - Prompt users before executing tool calls - Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
parent
85a2abbbf3
commit
3fde58b77d
@ -1,4 +1,4 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -7,8 +7,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
||||||
@ -27,10 +27,10 @@ Example result:
|
|||||||
}
|
}
|
||||||
`
|
`
|
||||||
|
|
||||||
var DirTreeTool = model.Tool{
|
var DirTreeTool = api.ToolSpec{
|
||||||
Name: "dir_tree",
|
Name: "dir_tree",
|
||||||
Description: TREE_DESCRIPTION,
|
Description: TREE_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "relative_path",
|
Name: "relative_path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -42,7 +42,7 @@ var DirTreeTool = model.Tool{
|
|||||||
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
|
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
var relativeDir string
|
var relativeDir string
|
||||||
if tmp, ok := args["relative_path"]; ok {
|
if tmp, ok := args["relative_path"]; ok {
|
||||||
relativeDir, ok = tmp.(string)
|
relativeDir, ok = tmp.(string)
|
||||||
@ -76,25 +76,25 @@ var DirTreeTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func tree(path string, depth int) model.CallResult {
|
func tree(path string, depth int) api.CallResult {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
path = "."
|
path = "."
|
||||||
}
|
}
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
var treeOutput strings.Builder
|
var treeOutput strings.Builder
|
||||||
treeOutput.WriteString(path + "\n")
|
treeOutput.WriteString(path + "\n")
|
||||||
err := buildTree(&treeOutput, path, "", depth)
|
err := buildTree(&treeOutput, path, "", depth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{
|
return api.CallResult{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: treeOutput.String()}
|
return api.CallResult{Result: treeOutput.String()}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
@ -1,22 +1,22 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||||
|
|
||||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||||
|
|
||||||
var FileInsertLinesTool = model.Tool{
|
var FileInsertLinesTool = api.ToolSpec{
|
||||||
Name: "file_insert_lines",
|
Name: "file_insert_lines",
|
||||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -36,7 +36,7 @@ var FileInsertLinesTool = model.Tool{
|
|||||||
Required: true,
|
Required: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
@ -72,27 +72,27 @@ var FileInsertLinesTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileInsertLines(path string, position int, content string) model.CallResult {
|
func fileInsertLines(path string, position int, content string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the existing file's content
|
// Read the existing file's content
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
_, err = os.Create(path)
|
_, err = os.Create(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
}
|
}
|
||||||
data = []byte{}
|
data = []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if position < 1 {
|
if position < 1 {
|
||||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
return api.CallResult{Message: "start_line cannot be less than 1"}
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
lines := strings.Split(string(data), "\n")
|
||||||
@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) model.CallResult
|
|||||||
// Join the lines and write back to the file
|
// Join the lines and write back to the file
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: newContent}
|
return api.CallResult{Result: newContent}
|
||||||
}
|
}
|
@ -1,12 +1,12 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||||
@ -15,10 +15,10 @@ Useful for re-writing snippets/blocks of code or entire functions.
|
|||||||
|
|
||||||
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
||||||
|
|
||||||
var FileReplaceLinesTool = model.Tool{
|
var FileReplaceLinesTool = api.ToolSpec{
|
||||||
Name: "file_replace_lines",
|
Name: "file_replace_lines",
|
||||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -42,7 +42,7 @@ var FileReplaceLinesTool = model.Tool{
|
|||||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
@ -87,27 +87,27 @@ var FileReplaceLinesTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
func fileReplaceLines(path string, startLine int, endLine int, content string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the existing file's content
|
// Read the existing file's content
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
_, err = os.Create(path)
|
_, err = os.Create(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
}
|
}
|
||||||
data = []byte{}
|
data = []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if startLine < 1 {
|
if startLine < 1 {
|
||||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
return api.CallResult{Message: "start_line cannot be less than 1"}
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
lines := strings.Split(string(data), "\n")
|
||||||
@ -126,8 +126,8 @@ func fileReplaceLines(path string, startLine int, endLine int, content string) m
|
|||||||
// Join the lines and write back to the file
|
// Join the lines and write back to the file
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: newContent}
|
return api.CallResult{Result: newContent}
|
||||||
}
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -6,8 +6,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||||
@ -25,17 +25,17 @@ Example result:
|
|||||||
For files, size represents the size of the file, in bytes.
|
For files, size represents the size of the file, in bytes.
|
||||||
For directories, size represents the number of entries in that directory.`
|
For directories, size represents the number of entries in that directory.`
|
||||||
|
|
||||||
var ReadDirTool = model.Tool{
|
var ReadDirTool = api.ToolSpec{
|
||||||
Name: "read_dir",
|
Name: "read_dir",
|
||||||
Description: READ_DIR_DESCRIPTION,
|
Description: READ_DIR_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "relative_dir",
|
Name: "relative_dir",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "If set, read the contents of a directory relative to the current one.",
|
Description: "If set, read the contents of a directory relative to the current one.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
var relativeDir string
|
var relativeDir string
|
||||||
tmp, ok := args["relative_dir"]
|
tmp, ok := args["relative_dir"]
|
||||||
if ok {
|
if ok {
|
||||||
@ -53,18 +53,18 @@ var ReadDirTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDir(path string) model.CallResult {
|
func readDir(path string) api.CallResult {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
path = "."
|
path = "."
|
||||||
}
|
}
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
files, err := os.ReadDir(path)
|
files, err := os.ReadDir(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{
|
return api.CallResult{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -96,5 +96,5 @@ func readDir(path string) model.CallResult {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: dirContents}
|
return api.CallResult{Result: dirContents}
|
||||||
}
|
}
|
@ -1,12 +1,12 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
||||||
@ -21,10 +21,10 @@ Example result:
|
|||||||
"result": "1\tthe contents\n2\tof the file\n"
|
"result": "1\tthe contents\n2\tof the file\n"
|
||||||
}`
|
}`
|
||||||
|
|
||||||
var ReadFileTool = model.Tool{
|
var ReadFileTool = api.ToolSpec{
|
||||||
Name: "read_file",
|
Name: "read_file",
|
||||||
Description: READ_FILE_DESCRIPTION,
|
Description: READ_FILE_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -33,7 +33,7 @@ var ReadFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||||
@ -51,14 +51,14 @@ var ReadFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func readFile(path string) model.CallResult {
|
func readFile(path string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
lines := strings.Split(string(data), "\n")
|
||||||
@ -67,7 +67,7 @@ func readFile(path string) model.CallResult {
|
|||||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{
|
return api.CallResult{
|
||||||
Result: content.String(),
|
Result: content.String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,11 +1,11 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||||
@ -15,10 +15,10 @@ Example result:
|
|||||||
"message": "success"
|
"message": "success"
|
||||||
}`
|
}`
|
||||||
|
|
||||||
var WriteFileTool = model.Tool{
|
var WriteFileTool = api.ToolSpec{
|
||||||
Name: "write_file",
|
Name: "write_file",
|
||||||
Description: WRITE_FILE_DESCRIPTION,
|
Description: WRITE_FILE_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -32,7 +32,7 @@ var WriteFileTool = model.Tool{
|
|||||||
Required: true,
|
Required: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||||
@ -58,14 +58,14 @@ var WriteFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeFile(path string, content string) model.CallResult {
|
func writeFile(path string, content string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
err := os.WriteFile(path, []byte(content), 0644)
|
err := os.WriteFile(path, []byte(content), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
return model.CallResult{}
|
return api.CallResult{}
|
||||||
}
|
}
|
48
pkg/agent/tools.go
Normal file
48
pkg/agent/tools.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agent/toolbox"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
|
||||||
|
"dir_tree": toolbox.DirTreeTool,
|
||||||
|
"read_dir": toolbox.ReadDirTool,
|
||||||
|
"read_file": toolbox.ReadFileTool,
|
||||||
|
"write_file": toolbox.WriteFileTool,
|
||||||
|
"file_insert_lines": toolbox.FileInsertLinesTool,
|
||||||
|
"file_replace_lines": toolbox.FileReplaceLinesTool,
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
|
||||||
|
var toolResults []api.ToolResult
|
||||||
|
for _, call := range calls {
|
||||||
|
var tool *api.ToolSpec
|
||||||
|
for i := range available {
|
||||||
|
if available[i].Name == call.Name {
|
||||||
|
tool = &available[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tool == nil {
|
||||||
|
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
result, err := tool.Impl(tool, call.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResult := api.ToolResult{
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
ToolName: call.Name,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, toolResult)
|
||||||
|
}
|
||||||
|
return toolResults, nil
|
||||||
|
}
|
@ -2,35 +2,41 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReplyCallback func(model.Message)
|
type ReplyCallback func(Message)
|
||||||
|
|
||||||
type Chunk struct {
|
type Chunk struct {
|
||||||
Content string
|
Content string
|
||||||
TokenCount uint
|
TokenCount uint
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionClient interface {
|
type RequestParameters struct {
|
||||||
|
Model string
|
||||||
|
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float32
|
||||||
|
TopP float32
|
||||||
|
|
||||||
|
ToolBag []ToolSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionProvider interface {
|
||||||
// CreateChatCompletion requests a response to the provided messages.
|
// CreateChatCompletion requests a response to the provided messages.
|
||||||
// Replies are appended to the given replies struct, and the
|
// Replies are appended to the given replies struct, and the
|
||||||
// complete user-facing response is returned as a string.
|
// complete user-facing response is returned as a string.
|
||||||
CreateChatCompletion(
|
CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params RequestParameters,
|
||||||
messages []model.Message,
|
messages []Message,
|
||||||
callback ReplyCallback,
|
) (*Message, error)
|
||||||
) (string, error)
|
|
||||||
|
|
||||||
// Like CreateChageCompletion, except the response is streamed via
|
// Like CreateChageCompletion, except the response is streamed via
|
||||||
// the output channel as it's received.
|
// the output channel as it's received.
|
||||||
CreateChatCompletionStream(
|
CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params RequestParameters,
|
||||||
messages []model.Message,
|
messages []Message,
|
||||||
callback ReplyCallback,
|
chunks chan<- Chunk,
|
||||||
output chan<- Chunk,
|
) (*Message, error)
|
||||||
) (string, error)
|
|
||||||
}
|
}
|
||||||
|
11
pkg/api/conversation.go
Normal file
11
pkg/api/conversation.go
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import "database/sql"
|
||||||
|
|
||||||
|
type Conversation struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ShortName sql.NullString
|
||||||
|
Title string
|
||||||
|
SelectedRootID *uint
|
||||||
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package model
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,24 +31,6 @@ type Message struct {
|
|||||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conversation struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ShortName sql.NullString
|
|
||||||
Title string
|
|
||||||
SelectedRootID *uint
|
|
||||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestParameters struct {
|
|
||||||
Model string
|
|
||||||
|
|
||||||
MaxTokens int
|
|
||||||
Temperature float32
|
|
||||||
TopP float32
|
|
||||||
|
|
||||||
ToolBag []Tool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MessageRole) IsAssistant() bool {
|
func (m *MessageRole) IsAssistant() bool {
|
||||||
switch *m {
|
switch *m {
|
||||||
case MessageRoleAssistant, MessageRoleToolCall:
|
case MessageRoleAssistant, MessageRoleToolCall:
|
@ -11,11 +11,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
|
||||||
requestBody := Request{
|
requestBody := Request{
|
||||||
Model: params.Model,
|
Model: params.Model,
|
||||||
Messages: make([]Message, len(messages)),
|
Messages: make([]Message, len(messages)),
|
||||||
@ -30,7 +28,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
}
|
}
|
||||||
|
|
||||||
startIdx := 0
|
startIdx := 0
|
||||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||||
requestBody.System = messages[0].Content
|
requestBody.System = messages[0].Content
|
||||||
requestBody.Messages = requestBody.Messages[1:]
|
requestBody.Messages = requestBody.Messages[1:]
|
||||||
startIdx = 1
|
startIdx = 1
|
||||||
@ -48,7 +46,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
message := &requestBody.Messages[i]
|
message := &requestBody.Messages[i]
|
||||||
|
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case model.MessageRoleToolCall:
|
case api.MessageRoleToolCall:
|
||||||
message.Role = "assistant"
|
message.Role = "assistant"
|
||||||
if msg.Content != "" {
|
if msg.Content != "" {
|
||||||
message.Content = msg.Content
|
message.Content = msg.Content
|
||||||
@ -63,7 +61,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
} else {
|
} else {
|
||||||
message.Content = xmlString
|
message.Content = xmlString
|
||||||
}
|
}
|
||||||
case model.MessageRoleToolResult:
|
case api.MessageRoleToolResult:
|
||||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||||
xmlString, err := xmlFuncResults.XMLString()
|
xmlString, err := xmlFuncResults.XMLString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -105,26 +103,25 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
|
|||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletion(
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
) (*api.Message, error) {
|
||||||
) (string, error) {
|
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
resp, err := sendRequest(ctx, c, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var response Response
|
var response Response
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to decode response: %v", err)
|
return nil, fmt.Errorf("failed to decode response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
@ -137,34 +134,28 @@ func (c *AnthropicClient) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, content := range response.Content {
|
for _, content := range response.Content {
|
||||||
var reply model.Message
|
|
||||||
switch content.Type {
|
switch content.Type {
|
||||||
case "text":
|
case "text":
|
||||||
reply = model.Message{
|
sb.WriteString(content.Text)
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.Text,
|
|
||||||
}
|
|
||||||
sb.WriteString(reply.Content)
|
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
|
||||||
}
|
|
||||||
if callback != nil {
|
|
||||||
callback(reply)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String(), nil
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: sb.String(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
|
||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (string, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
request := buildRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
@ -172,19 +163,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
resp, err := sendRequest(ctx, c, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
|
|
||||||
lastMessage := messages[len(messages)-1]
|
lastMessage := messages[len(messages)-1]
|
||||||
continuation := false
|
|
||||||
if messages[len(messages)-1].Role.IsAssistant() {
|
if messages[len(messages)-1].Role.IsAssistant() {
|
||||||
// this is a continuation of a previous assistant reply, so we'll
|
// this is a continuation of a previous assistant reply, so we'll
|
||||||
// include its contents in the final result
|
// include its contents in the final result
|
||||||
|
// TODO: handle this at higher level
|
||||||
sb.WriteString(lastMessage.Content)
|
sb.WriteString(lastMessage.Content)
|
||||||
continuation = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
@ -200,29 +190,29 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
var event map[string]interface{}
|
var event map[string]interface{}
|
||||||
err := json.Unmarshal([]byte(line), &event)
|
err := json.Unmarshal([]byte(line), &event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||||
}
|
}
|
||||||
eventType, ok := event["type"].(string)
|
eventType, ok := event["type"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("invalid event: %s", line)
|
return nil, fmt.Errorf("invalid event: %s", line)
|
||||||
}
|
}
|
||||||
switch eventType {
|
switch eventType {
|
||||||
case "error":
|
case "error":
|
||||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
default:
|
default:
|
||||||
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
return nil, fmt.Errorf("unknown event type: %s", eventType)
|
||||||
}
|
}
|
||||||
} else if strings.HasPrefix(line, "data:") {
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
var event map[string]interface{}
|
var event map[string]interface{}
|
||||||
err := json.Unmarshal([]byte(data), &event)
|
err := json.Unmarshal([]byte(data), &event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
eventType, ok := event["type"].(string)
|
eventType, ok := event["type"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("invalid event type")
|
return nil, fmt.Errorf("invalid event type")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch eventType {
|
switch eventType {
|
||||||
@ -235,11 +225,11 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("invalid content block delta")
|
return nil, fmt.Errorf("invalid content block delta")
|
||||||
}
|
}
|
||||||
text, ok := delta["text"].(string)
|
text, ok := delta["text"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("invalid text delta")
|
return nil, fmt.Errorf("invalid text delta")
|
||||||
}
|
}
|
||||||
sb.WriteString(text)
|
sb.WriteString(text)
|
||||||
output <- api.Chunk{
|
output <- api.Chunk{
|
||||||
@ -251,7 +241,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
case "message_delta":
|
case "message_delta":
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("invalid message delta")
|
return nil, fmt.Errorf("invalid message delta")
|
||||||
}
|
}
|
||||||
stopReason, ok := delta["stop_reason"].(string)
|
stopReason, ok := delta["stop_reason"].(string)
|
||||||
if ok && stopReason == "stop_sequence" {
|
if ok && stopReason == "stop_sequence" {
|
||||||
@ -261,7 +251,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
|
|
||||||
start := strings.Index(content, "<function_calls>")
|
start := strings.Index(content, "<function_calls>")
|
||||||
if start == -1 {
|
if start == -1 {
|
||||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||||
}
|
}
|
||||||
|
|
||||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||||
@ -269,59 +259,31 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
Content: FUNCTION_STOP_SEQUENCE,
|
Content: FUNCTION_STOP_SEQUENCE,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
var functionCalls XMLFunctionCalls
|
var functionCalls XMLFunctionCalls
|
||||||
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCall := model.Message{
|
return &api.Message{
|
||||||
Role: model.MessageRoleToolCall,
|
Role: api.MessageRoleToolCall,
|
||||||
// function call xml stripped from content for model interop
|
// function call xml stripped from content for model interop
|
||||||
Content: strings.TrimSpace(content[:start]),
|
Content: strings.TrimSpace(content[:start]),
|
||||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.Message{
|
|
||||||
Role: model.MessageRoleToolResult,
|
|
||||||
ToolResults: toolResults,
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(toolCall)
|
|
||||||
callback(toolResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
if continuation {
|
|
||||||
messages[len(messages)-1] = toolCall
|
|
||||||
} else {
|
|
||||||
messages = append(messages, toolCall)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = append(messages, toolResult)
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "message_stop":
|
case "message_stop":
|
||||||
// return the completed message
|
// return the completed message
|
||||||
content := sb.String()
|
content := sb.String()
|
||||||
if callback != nil {
|
return &api.Message{
|
||||||
callback(model.Message{
|
Role: api.MessageRoleAssistant,
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content,
|
Content: content,
|
||||||
})
|
}, nil
|
||||||
}
|
|
||||||
return content, nil
|
|
||||||
case "error":
|
case "error":
|
||||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
default:
|
default:
|
||||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||||
}
|
}
|
||||||
@ -329,8 +291,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
return "", fmt.Errorf("failed to read response body: %v", err)
|
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", fmt.Errorf("unexpected end of stream")
|
return nil, fmt.Errorf("unexpected end of stream")
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
|
||||||
converted := make([]XMLToolDescription, len(tools))
|
converted := make([]XMLToolDescription, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
converted[i].ToolName = tool.Name
|
converted[i].ToolName = tool.Name
|
||||||
@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
|
||||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
|
||||||
for i, invoke := range functionCalls.Invoke {
|
for i, invoke := range functionCalls.Invoke {
|
||||||
toolCalls[i].Name = invoke.ToolName
|
toolCalls[i].Name = invoke.ToolName
|
||||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||||
@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
|
|||||||
return toolCalls
|
return toolCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
|
||||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||||
for i, toolCall := range toolCalls {
|
for i, toolCall := range toolCalls {
|
||||||
var params XMLFunctionInvokeParameters
|
var params XMLFunctionInvokeParameters
|
||||||
@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionC
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
|
||||||
converted := make([]XMLFunctionResult, len(toolResults))
|
converted := make([]XMLFunctionResult, len(toolResults))
|
||||||
for i, result := range toolResults {
|
for i, result := range toolResults {
|
||||||
converted[i].ToolName = result.ToolName
|
converted[i].ToolName = result.ToolName
|
||||||
@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFu
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
|
||||||
xmlTools := convertToolsToXMLTools(tools)
|
xmlTools := convertToolsToXMLTools(tools)
|
||||||
xmlToolsString, err := xmlTools.XMLString()
|
xmlToolsString, err := xmlTools.XMLString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Could not serialize []model.Tool to XMLTools")
|
panic("Could not serialize []api.Tool to XMLTools")
|
||||||
}
|
}
|
||||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||||
}
|
}
|
||||||
|
@ -11,11 +11,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertTools(tools []model.Tool) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
geminiTools := make([]Tool, len(tools))
|
geminiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
params := make(map[string]ToolParameter)
|
params := make(map[string]ToolParameter)
|
||||||
@ -50,7 +48,7 @@ func convertTools(tools []model.Tool) []Tool {
|
|||||||
return geminiTools
|
return geminiTools
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
||||||
converted := make([]ContentPart, len(toolCalls))
|
converted := make([]ContentPart, len(toolCalls))
|
||||||
for i, call := range toolCalls {
|
for i, call := range toolCalls {
|
||||||
args := make(map[string]string)
|
args := make(map[string]string)
|
||||||
@ -65,8 +63,8 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
||||||
converted := make([]model.ToolCall, len(functionCalls))
|
converted := make([]api.ToolCall, len(functionCalls))
|
||||||
for i, call := range functionCalls {
|
for i, call := range functionCalls {
|
||||||
params := make(map[string]interface{})
|
params := make(map[string]interface{})
|
||||||
for k, v := range call.Args {
|
for k, v := range call.Args {
|
||||||
@ -78,7 +76,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
|
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
||||||
results := make([]FunctionResponse, len(toolResults))
|
results := make([]FunctionResponse, len(toolResults))
|
||||||
for i, result := range toolResults {
|
for i, result := range toolResults {
|
||||||
var obj interface{}
|
var obj interface{}
|
||||||
@ -95,14 +93,14 @@ func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionRespo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createGenerateContentRequest(
|
func createGenerateContentRequest(
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
) (*GenerateContentRequest, error) {
|
) (*GenerateContentRequest, error) {
|
||||||
requestContents := make([]Content, 0, len(messages))
|
requestContents := make([]Content, 0, len(messages))
|
||||||
|
|
||||||
startIdx := 0
|
startIdx := 0
|
||||||
var system string
|
var system string
|
||||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||||
system = messages[0].Content
|
system = messages[0].Content
|
||||||
startIdx = 1
|
startIdx = 1
|
||||||
}
|
}
|
||||||
@ -135,9 +133,9 @@ func createGenerateContentRequest(
|
|||||||
default:
|
default:
|
||||||
var role string
|
var role string
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case model.MessageRoleAssistant:
|
case api.MessageRoleAssistant:
|
||||||
role = "model"
|
role = "model"
|
||||||
case model.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
role = "user"
|
role = "user"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -183,55 +181,14 @@ func createGenerateContentRequest(
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleToolCalls(
|
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||||
params model.RequestParameters,
|
|
||||||
content string,
|
|
||||||
toolCalls []model.ToolCall,
|
|
||||||
callback api.ReplyCallback,
|
|
||||||
messages []model.Message,
|
|
||||||
) ([]model.Message, error) {
|
|
||||||
lastMessage := messages[len(messages)-1]
|
|
||||||
continuation := false
|
|
||||||
if lastMessage.Role.IsAssistant() {
|
|
||||||
continuation = true
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCall := model.Message{
|
|
||||||
Role: model.MessageRoleToolCall,
|
|
||||||
Content: content,
|
|
||||||
ToolCalls: toolCalls,
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.Message{
|
|
||||||
Role: model.MessageRoleToolResult,
|
|
||||||
ToolResults: toolResults,
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(toolCall)
|
|
||||||
callback(toolResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
if continuation {
|
|
||||||
messages[len(messages)-1] = toolCall
|
|
||||||
} else {
|
|
||||||
messages = append(messages, toolCall)
|
|
||||||
}
|
|
||||||
messages = append(messages, toolResult)
|
|
||||||
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req.WithContext(ctx))
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
bytes, _ := io.ReadAll(resp.Body)
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
@ -243,42 +200,41 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Resp
|
|||||||
|
|
||||||
func (c *Client) CreateChatCompletion(
|
func (c *Client) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
) (*api.Message, error) {
|
||||||
) (string, error) {
|
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := createGenerateContentRequest(params, messages)
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf(
|
url := fmt.Sprintf(
|
||||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||||
c.BaseURL, params.Model, c.APIKey,
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
)
|
)
|
||||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var completionResp GenerateContentResponse
|
var completionResp GenerateContentResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := completionResp.Candidates[0]
|
choice := completionResp.Candidates[0]
|
||||||
@ -301,58 +257,50 @@ func (c *Client) CreateChatCompletion(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(
|
return &api.Message{
|
||||||
params, content, convertToolCallToAPI(toolCalls), callback, messages,
|
Role: api.MessageRoleToolCall,
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return content, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content,
|
Content: content,
|
||||||
})
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return content, nil
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) CreateChatCompletionStream(
|
func (c *Client) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
|
||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (string, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := createGenerateContentRequest(params, messages)
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf(
|
url := fmt.Sprintf(
|
||||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||||
c.BaseURL, params.Model, c.APIKey,
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
)
|
)
|
||||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@ -374,7 +322,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
line = bytes.TrimSpace(line)
|
||||||
@ -387,7 +335,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
var resp GenerateContentResponse
|
var resp GenerateContentResponse
|
||||||
err = json.Unmarshal(line, &resp)
|
err = json.Unmarshal(line, &resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||||
@ -409,21 +357,15 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
|
|
||||||
// If there are function calls, handle them and recurse
|
// If there are function calls, handle them and recurse
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(
|
return &api.Message{
|
||||||
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
|
Role: api.MessageRoleToolCall,
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return content.String(), err
|
|
||||||
}
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
Content: content.String(),
|
||||||
})
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return content.String(), nil
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: content.String(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type OllamaClient struct {
|
type OllamaClient struct {
|
||||||
@ -43,8 +42,8 @@ type OllamaResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createOllamaRequest(
|
func createOllamaRequest(
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
) OllamaRequest {
|
) OllamaRequest {
|
||||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||||
|
|
||||||
@ -64,11 +63,11 @@ func createOllamaRequest(
|
|||||||
return request
|
return request
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req.WithContext(ctx))
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -83,12 +82,11 @@ func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*htt
|
|||||||
|
|
||||||
func (c *OllamaClient) CreateChatCompletion(
|
func (c *OllamaClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
) (*api.Message, error) {
|
||||||
) (string, error) {
|
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createOllamaRequest(params, messages)
|
req := createOllamaRequest(params, messages)
|
||||||
@ -96,46 +94,40 @@ func (c *OllamaClient) CreateChatCompletion(
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var completionResp OllamaResponse
|
var completionResp OllamaResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
content := completionResp.Message.Content
|
return &api.Message{
|
||||||
if callback != nil {
|
Role: api.MessageRoleAssistant,
|
||||||
callback(model.Message{
|
Content: completionResp.Message.Content,
|
||||||
Role: model.MessageRoleAssistant,
|
}, nil
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return content, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OllamaClient) CreateChatCompletionStream(
|
func (c *OllamaClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
|
||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (string, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createOllamaRequest(params, messages)
|
req := createOllamaRequest(params, messages)
|
||||||
@ -143,17 +135,17 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@ -166,7 +158,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
line = bytes.TrimSpace(line)
|
||||||
@ -177,7 +169,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
var streamResp OllamaResponse
|
var streamResp OllamaResponse
|
||||||
err = json.Unmarshal(line, &streamResp)
|
err = json.Unmarshal(line, &streamResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(streamResp.Message.Content) > 0 {
|
if len(streamResp.Message.Content) > 0 {
|
||||||
@ -189,12 +181,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if callback != nil {
|
return &api.Message{
|
||||||
callback(model.Message{
|
Role: api.MessageRoleAssistant,
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
Content: content.String(),
|
||||||
})
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
return content.String(), nil
|
|
||||||
}
|
}
|
||||||
|
@ -11,11 +11,9 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func convertTools(tools []model.Tool) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
openaiTools := make([]Tool, len(tools))
|
openaiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
openaiTools[i].Type = "function"
|
openaiTools[i].Type = "function"
|
||||||
@ -47,7 +45,7 @@ func convertTools(tools []model.Tool) []Tool {
|
|||||||
return openaiTools
|
return openaiTools
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
||||||
converted := make([]ToolCall, len(toolCalls))
|
converted := make([]ToolCall, len(toolCalls))
|
||||||
for i, call := range toolCalls {
|
for i, call := range toolCalls {
|
||||||
converted[i].Type = "function"
|
converted[i].Type = "function"
|
||||||
@ -60,8 +58,8 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
|
|||||||
return converted
|
return converted
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||||
converted := make([]model.ToolCall, len(toolCalls))
|
converted := make([]api.ToolCall, len(toolCalls))
|
||||||
for i, call := range toolCalls {
|
for i, call := range toolCalls {
|
||||||
converted[i].ID = call.ID
|
converted[i].ID = call.ID
|
||||||
converted[i].Name = call.Function.Name
|
converted[i].Name = call.Function.Name
|
||||||
@ -71,8 +69,8 @@ func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
func createChatCompletionRequest(
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
) ChatCompletionRequest {
|
) ChatCompletionRequest {
|
||||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
|
|
||||||
@ -117,56 +115,15 @@ func createChatCompletionRequest(
|
|||||||
return request
|
return request
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleToolCalls(
|
func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||||
params model.RequestParameters,
|
|
||||||
content string,
|
|
||||||
toolCalls []ToolCall,
|
|
||||||
callback api.ReplyCallback,
|
|
||||||
messages []model.Message,
|
|
||||||
) ([]model.Message, error) {
|
|
||||||
lastMessage := messages[len(messages)-1]
|
|
||||||
continuation := false
|
|
||||||
if lastMessage.Role.IsAssistant() {
|
|
||||||
continuation = true
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCall := model.Message{
|
|
||||||
Role: model.MessageRoleToolCall,
|
|
||||||
Content: content,
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.Message{
|
|
||||||
Role: model.MessageRoleToolResult,
|
|
||||||
ToolResults: toolResults,
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(toolCall)
|
|
||||||
callback(toolResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
if continuation {
|
|
||||||
messages[len(messages)-1] = toolCall
|
|
||||||
} else {
|
|
||||||
messages = append(messages, toolCall)
|
|
||||||
}
|
|
||||||
messages = append(messages, toolResult)
|
|
||||||
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req.WithContext(ctx))
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
bytes, _ := io.ReadAll(resp.Body)
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
@ -178,35 +135,34 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*htt
|
|||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletion(
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
) (*api.Message, error) {
|
||||||
) (string, error) {
|
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createChatCompletionRequest(params, messages)
|
req := createChatCompletionRequest(params, messages)
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var completionResp ChatCompletionResponse
|
var completionResp ChatCompletionResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
choice := completionResp.Choices[0]
|
choice := completionResp.Choices[0]
|
||||||
@ -221,34 +177,27 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
|
|
||||||
toolCalls := choice.Message.ToolCalls
|
toolCalls := choice.Message.ToolCalls
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
|
return &api.Message{
|
||||||
if err != nil {
|
Role: api.MessageRoleToolCall,
|
||||||
return content, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content,
|
Content: content,
|
||||||
})
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the user-facing message.
|
return &api.Message{
|
||||||
return content, nil
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params model.RequestParameters,
|
params api.RequestParameters,
|
||||||
messages []model.Message,
|
messages []api.Message,
|
||||||
callback api.ReplyCallback,
|
|
||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (string, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return "", fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
req := createChatCompletionRequest(params, messages)
|
req := createChatCompletionRequest(params, messages)
|
||||||
@ -256,17 +205,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
|
|
||||||
jsonData, err := json.Marshal(req)
|
jsonData, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, httpReq)
|
resp, err := c.sendRequest(httpReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
@ -285,7 +234,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
line = bytes.TrimSpace(line)
|
||||||
@ -301,7 +250,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
var streamResp ChatCompletionStreamResponse
|
var streamResp ChatCompletionStreamResponse
|
||||||
err = json.Unmarshal(line, &streamResp)
|
err = json.Unmarshal(line, &streamResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
delta := streamResp.Choices[0].Delta
|
delta := streamResp.Choices[0].Delta
|
||||||
@ -309,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
// Construct streamed tool_call arguments
|
// Construct streamed tool_call arguments
|
||||||
for _, tc := range delta.ToolCalls {
|
for _, tc := range delta.ToolCalls {
|
||||||
if tc.Index == nil {
|
if tc.Index == nil {
|
||||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||||
}
|
}
|
||||||
if len(toolCalls) <= *tc.Index {
|
if len(toolCalls) <= *tc.Index {
|
||||||
toolCalls = append(toolCalls, tc)
|
toolCalls = append(toolCalls, tc)
|
||||||
@ -328,21 +277,15 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
|
return &api.Message{
|
||||||
if err != nil {
|
Role: api.MessageRoleToolCall,
|
||||||
return content.String(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
} else {
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
Content: content.String(),
|
||||||
})
|
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return content.String(), nil
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: content.String(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package model
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
@ -6,11 +6,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tool struct {
|
type ToolSpec struct {
|
||||||
Name string
|
Name string
|
||||||
Description string
|
Description string
|
||||||
Parameters []ToolParameter
|
Parameters []ToolParameter
|
||||||
Impl func(*Tool, map[string]interface{}) (string, error)
|
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolParameter struct {
|
type ToolParameter struct {
|
||||||
@ -27,6 +27,12 @@ type ToolCall struct {
|
|||||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolResult struct {
|
||||||
|
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||||
|
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||||
|
Result string `json:"result,omitempty" yaml:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
type ToolCalls []ToolCall
|
type ToolCalls []ToolCall
|
||||||
|
|
||||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||||
@ -50,12 +56,6 @@ func (tc ToolCalls) Value() (driver.Value, error) {
|
|||||||
return string(jsonBytes), nil
|
return string(jsonBytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolResult struct {
|
|
||||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
|
||||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
|
||||||
Result string `json:"result,omitempty" yaml:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResults []ToolResult
|
type ToolResults []ToolResult
|
||||||
|
|
||||||
func (tr *ToolResults) Scan(value any) (err error) {
|
func (tr *ToolResults) Scan(value any) (err error) {
|
@ -4,9 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lastMessage := &messages[len(messages)-1]
|
lastMessage := &messages[len(messages)-1]
|
||||||
if lastMessage.Role != model.MessageRoleAssistant {
|
if lastMessage.Role != api.MessageRoleAssistant {
|
||||||
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,7 +50,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append the new response to the original message
|
// Append the new response to the original message
|
||||||
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
|
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
|
||||||
|
|
||||||
// Update the original message
|
// Update the original message
|
||||||
err = ctx.Store.UpdateMessage(lastMessage)
|
err = ctx.Store.UpdateMessage(lastMessage)
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,10 +53,10 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
role, _ := cmd.Flags().GetString("role")
|
role, _ := cmd.Flags().GetString("role")
|
||||||
if role != "" {
|
if role != "" {
|
||||||
if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) {
|
||||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||||
}
|
}
|
||||||
toEdit.Role = model.MessageRole(role)
|
toEdit.Role = api.MessageRole(role)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the message in-place
|
// Update the message in-place
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,19 +20,19 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []model.Message
|
var messages []api.Message
|
||||||
|
|
||||||
// TODO: probably just make this part of the conversation
|
// TODO: probably just make this part of the conversation
|
||||||
system := ctx.GetSystemPrompt()
|
system := ctx.GetSystemPrompt()
|
||||||
if system != "" {
|
if system != "" {
|
||||||
messages = append(messages, model.Message{
|
messages = append(messages, api.Message{
|
||||||
Role: model.MessageRoleSystem,
|
Role: api.MessageRoleSystem,
|
||||||
Content: system,
|
Content: system,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, model.Message{
|
messages = append(messages, api.Message{
|
||||||
Role: model.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -20,19 +20,19 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []model.Message
|
var messages []api.Message
|
||||||
|
|
||||||
// TODO: stop supplying system prompt as a message
|
// TODO: stop supplying system prompt as a message
|
||||||
system := ctx.GetSystemPrompt()
|
system := ctx.GetSystemPrompt()
|
||||||
if system != "" {
|
if system != "" {
|
||||||
messages = append(messages, model.Message{
|
messages = append(messages, api.Message{
|
||||||
Role: model.MessageRoleSystem,
|
Role: api.MessageRoleSystem,
|
||||||
Content: system,
|
Content: system,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
messages = append(messages, model.Message{
|
messages = append(messages, api.Message{
|
||||||
Role: model.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -4,9 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
var toRemove []*model.Conversation
|
var toRemove []*api.Conversation
|
||||||
for _, shortName := range args {
|
for _, shortName := range args {
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
toRemove = append(toRemove, conversation)
|
toRemove = append(toRemove, conversation)
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -30,8 +30,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No reply was provided.")
|
return fmt.Errorf("No reply was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{
|
||||||
Role: model.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: reply,
|
Content: reply,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -43,11 +43,11 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
retryFromIdx := len(messages) - 1 - offset
|
retryFromIdx := len(messages) - 1 - offset
|
||||||
|
|
||||||
// decrease retryFromIdx until we hit a user message
|
// decrease retryFromIdx until we hit a user message
|
||||||
for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser {
|
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||||
retryFromIdx--
|
retryFromIdx--
|
||||||
}
|
}
|
||||||
|
|
||||||
if messages[retryFromIdx].Role != model.MessageRoleUser {
|
if messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||||
return fmt.Errorf("No user messages to retry")
|
return fmt.Errorf("No user messages to retry")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -10,36 +10,36 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Prompt prompts the configured the configured model and streams the response
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
|
||||||
content := make(chan api.Chunk) // receives the reponse from LLM
|
|
||||||
defer close(content)
|
|
||||||
|
|
||||||
// render all content received over the channel
|
|
||||||
go ShowDelayedContent(content)
|
|
||||||
|
|
||||||
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
|
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := model.RequestParameters{
|
requestParams := api.RequestParameters{
|
||||||
Model: m,
|
Model: m,
|
||||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *ctx.Config.Defaults.Temperature,
|
Temperature: *ctx.Config.Defaults.Temperature,
|
||||||
ToolBag: ctx.EnabledTools,
|
ToolBag: ctx.EnabledTools,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := provider.CreateChatCompletionStream(
|
content := make(chan api.Chunk)
|
||||||
context.Background(), requestParams, messages, callback, content,
|
defer close(content)
|
||||||
|
|
||||||
|
// render the content received over the channel
|
||||||
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
|
reply, err := provider.CreateChatCompletionStream(
|
||||||
|
context.Background(), requestParams, messages, content,
|
||||||
)
|
)
|
||||||
if response != "" {
|
|
||||||
|
if reply.Content != "" {
|
||||||
// there was some content, so break to a new line after it
|
// there was some content, so break to a new line after it
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
@ -48,12 +48,12 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
|
|||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return response, err
|
return reply, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupConversation either returns the conversation found by the
|
// lookupConversation either returns the conversation found by the
|
||||||
// short name or exits the program
|
// short name or exits the program
|
||||||
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||||
@ -64,7 +64,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversatio
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||||
@ -75,7 +75,7 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
|
||||||
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
@ -85,7 +85,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
|
|||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
// handleConversationReply handles sending messages to an existing
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
|
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
|
||||||
if to == nil {
|
if to == nil {
|
||||||
lmcli.Fatal("Can't prompt from an empty message.")
|
lmcli.Fatal("Can't prompt from an empty message.")
|
||||||
}
|
}
|
||||||
@ -97,7 +97,7 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
|
|||||||
|
|
||||||
RenderConversation(ctx, append(existing, messages...), true)
|
RenderConversation(ctx, append(existing, messages...), true)
|
||||||
|
|
||||||
var savedReplies []model.Message
|
var savedReplies []api.Message
|
||||||
if persist && len(messages) > 0 {
|
if persist && len(messages) > 0 {
|
||||||
savedReplies, err = ctx.Store.Reply(to, messages...)
|
savedReplies, err = ctx.Store.Reply(to, messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -106,15 +106,15 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
|
|||||||
}
|
}
|
||||||
|
|
||||||
// render a message header with no contents
|
// render a message header with no contents
|
||||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))
|
||||||
|
|
||||||
var lastSavedMessage *model.Message
|
var lastSavedMessage *api.Message
|
||||||
lastSavedMessage = to
|
lastSavedMessage = to
|
||||||
if len(savedReplies) > 0 {
|
if len(savedReplies) > 0 {
|
||||||
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
||||||
}
|
}
|
||||||
|
|
||||||
replyCallback := func(reply model.Message) {
|
replyCallback := func(reply api.Message) {
|
||||||
if !persist {
|
if !persist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -131,16 +131,16 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
func FormatForExternalPrompt(messages []api.Message, system bool) string {
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
if message.Content == "" {
|
if message.Content == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case model.MessageRoleAssistant, model.MessageRoleToolCall:
|
case api.MessageRoleAssistant, api.MessageRoleToolCall:
|
||||||
sb.WriteString("Assistant:\n\n")
|
sb.WriteString("Assistant:\n\n")
|
||||||
case model.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
sb.WriteString("User:\n\n")
|
sb.WriteString("User:\n\n")
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
@ -150,7 +150,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
|
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
|
||||||
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
||||||
|
|
||||||
Example conversation:
|
Example conversation:
|
||||||
@ -177,28 +177,32 @@ Example response:
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
generateRequest := []model.Message{
|
generateRequest := []api.Message{
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleSystem,
|
Role: api.MessageRoleSystem,
|
||||||
Content: systemPrompt,
|
Content: systemPrompt,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: string(conversation),
|
Content: string(conversation),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
m, provider, err := ctx.GetModelProvider(
|
||||||
|
*ctx.Config.Conversations.TitleGenerationModel,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := model.RequestParameters{
|
requestParams := api.RequestParameters{
|
||||||
Model: m,
|
Model: m,
|
||||||
MaxTokens: 25,
|
MaxTokens: 25,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
response, err := provider.CreateChatCompletion(
|
||||||
|
context.Background(), requestParams, generateRequest,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -207,7 +211,7 @@ Example response:
|
|||||||
var jsonResponse struct {
|
var jsonResponse struct {
|
||||||
Title string `json:"title"`
|
Title string `json:"title"`
|
||||||
}
|
}
|
||||||
err = json.Unmarshal([]byte(response), &jsonResponse)
|
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -272,7 +276,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {
|
|||||||
// RenderConversation renders the given messages to TTY, with optional space
|
// RenderConversation renders the given messages to TTY, with optional space
|
||||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||||
// are printed immediately after the final message (1 if false, 2 if true)
|
// are printed immediately after the final message (1 if false, 2 if true)
|
||||||
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
|
||||||
l := len(messages)
|
l := len(messages)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
RenderMessage(ctx, &message)
|
RenderMessage(ctx, &message)
|
||||||
@ -283,7 +287,7 @@ func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
func RenderMessage(ctx *lmcli.Context, m *api.Message) {
|
||||||
var messageAge string
|
var messageAge string
|
||||||
if m.CreatedAt.IsZero() {
|
if m.CreatedAt.IsZero() {
|
||||||
messageAge = "now"
|
messageAge = "now"
|
||||||
@ -295,11 +299,11 @@ func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
|||||||
headerStyle := lipgloss.NewStyle().Bold(true)
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case model.MessageRoleSystem:
|
case api.MessageRoleSystem:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||||
case model.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||||
case model.MessageRoleAssistant:
|
case api.MessageRoleAssistant:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,13 +6,12 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agent"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
@ -24,7 +23,7 @@ type Context struct {
|
|||||||
Store ConversationStore
|
Store ConversationStore
|
||||||
|
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma *tty.ChromaHighlighter
|
||||||
EnabledTools []model.Tool
|
EnabledTools []api.ToolSpec
|
||||||
|
|
||||||
SystemPromptFile string
|
SystemPromptFile string
|
||||||
}
|
}
|
||||||
@ -50,9 +49,9 @@ func NewContext() (*Context, error) {
|
|||||||
|
|
||||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||||
|
|
||||||
var enabledTools []model.Tool
|
var enabledTools []api.ToolSpec
|
||||||
for _, toolName := range config.Tools.EnabledTools {
|
for _, toolName := range config.Tools.EnabledTools {
|
||||||
tool, ok := tools.AvailableTools[toolName]
|
tool, ok := agent.AvailableTools[toolName]
|
||||||
if ok {
|
if ok {
|
||||||
enabledTools = append(enabledTools, tool)
|
enabledTools = append(enabledTools, tool)
|
||||||
}
|
}
|
||||||
@ -79,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
|
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
var provider string
|
var provider string
|
||||||
|
@ -8,32 +8,32 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
sqids "github.com/sqids/sqids-go"
|
sqids "github.com/sqids/sqids-go"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConversationStore interface {
|
type ConversationStore interface {
|
||||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
ConversationByShortName(shortName string) (*api.Conversation, error)
|
||||||
ConversationShortNameCompletions(search string) []string
|
ConversationShortNameCompletions(search string) []string
|
||||||
RootMessages(conversationID uint) ([]model.Message, error)
|
RootMessages(conversationID uint) ([]api.Message, error)
|
||||||
LatestConversationMessages() ([]model.Message, error)
|
LatestConversationMessages() ([]api.Message, error)
|
||||||
|
|
||||||
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
|
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
|
||||||
UpdateConversation(conversation *model.Conversation) error
|
UpdateConversation(conversation *api.Conversation) error
|
||||||
DeleteConversation(conversation *model.Conversation) error
|
DeleteConversation(conversation *api.Conversation) error
|
||||||
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
|
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
|
||||||
|
|
||||||
MessageByID(messageID uint) (*model.Message, error)
|
MessageByID(messageID uint) (*api.Message, error)
|
||||||
MessageReplies(messageID uint) ([]model.Message, error)
|
MessageReplies(messageID uint) ([]api.Message, error)
|
||||||
|
|
||||||
UpdateMessage(message *model.Message) error
|
UpdateMessage(message *api.Message) error
|
||||||
DeleteMessage(message *model.Message, prune bool) error
|
DeleteMessage(message *api.Message, prune bool) error
|
||||||
CloneBranch(toClone model.Message) (*model.Message, uint, error)
|
CloneBranch(toClone api.Message) (*api.Message, uint, error)
|
||||||
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
|
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
|
||||||
|
|
||||||
PathToRoot(message *model.Message) ([]model.Message, error)
|
PathToRoot(message *api.Message) ([]api.Message, error)
|
||||||
PathToLeaf(message *model.Message) ([]model.Message, error)
|
PathToLeaf(message *api.Message) ([]api.Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type SQLStore struct {
|
type SQLStore struct {
|
||||||
@ -43,8 +43,8 @@ type SQLStore struct {
|
|||||||
|
|
||||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||||
models := []any{
|
models := []any{
|
||||||
&model.Conversation{},
|
&api.Conversation{},
|
||||||
&model.Message{},
|
&api.Message{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, x := range models {
|
for _, x := range models {
|
||||||
@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
|||||||
return &SQLStore{db, _sqids}, nil
|
return &SQLStore{db, _sqids}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) createConversation() (*model.Conversation, error) {
|
func (s *SQLStore) createConversation() (*api.Conversation, error) {
|
||||||
// Create the new conversation
|
// Create the new conversation
|
||||||
c := &model.Conversation{}
|
c := &api.Conversation{}
|
||||||
err := s.db.Save(c).Error
|
err := s.db.Save(c).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*model.Conversation, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
|
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
|
||||||
if c == nil || c.ID == 0 {
|
if c == nil || c.ID == 0 {
|
||||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
return s.db.Updates(c).Error
|
return s.db.Updates(c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
|
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
|
||||||
// Delete messages first
|
// Delete messages first
|
||||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
|
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return s.db.Delete(c).Error
|
return s.db.Delete(c).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
|
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
|
||||||
panic("Not yet implemented")
|
panic("Not yet implemented")
|
||||||
//return s.db.Delete(&message).Error
|
//return s.db.Delete(&message).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
func (s *SQLStore) UpdateMessage(m *api.Message) error {
|
||||||
if m == nil || m.ID == 0 {
|
if m == nil || m.ID == 0 {
|
||||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||||
}
|
}
|
||||||
@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||||
var conversations []model.Conversation
|
var conversations []api.Conversation
|
||||||
// ignore error for completions
|
// ignore error for completions
|
||||||
s.db.Find(&conversations)
|
s.db.Find(&conversations)
|
||||||
completions := make([]string, 0, len(conversations))
|
completions := make([]string, 0, len(conversations))
|
||||||
@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
|||||||
return completions
|
return completions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
|
||||||
if shortName == "" {
|
if shortName == "" {
|
||||||
return nil, errors.New("shortName is empty")
|
return nil, errors.New("shortName is empty")
|
||||||
}
|
}
|
||||||
var conversation model.Conversation
|
var conversation api.Conversation
|
||||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
return &conversation, err
|
return &conversation, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
|
||||||
var rootMessages []model.Message
|
var rootMessages []api.Message
|
||||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
|
|||||||
return rootMessages, nil
|
return rootMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
|
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
|
||||||
var message model.Message
|
var message api.Message
|
||||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||||
return &message, err
|
return &message, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
|
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
|
||||||
var replies []model.Message
|
var replies []api.Message
|
||||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||||
return replies, err
|
return replies, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartConversation starts a new conversation with the provided messages
|
// StartConversation starts a new conversation with the provided messages
|
||||||
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
|
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||||
}
|
}
|
||||||
@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
messages = append([]model.Message{messages[0]}, newMessages...)
|
messages = append([]api.Message{messages[0]}, newMessages...)
|
||||||
}
|
}
|
||||||
return conversation, messages, nil
|
return conversation, messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloneConversation clones the given conversation and all of its root meesages
|
// CloneConversation clones the given conversation and all of its root meesages
|
||||||
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
|
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
|
||||||
rootMessages, err := s.RootMessages(toClone.ID)
|
rootMessages, err := s.RootMessages(toClone.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reply to a message with a series of messages (each following the next)
|
// Reply to a message with a series of messages (each following the next)
|
||||||
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
|
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
|
||||||
var savedMessages []model.Message
|
var savedMessages []api.Message
|
||||||
|
|
||||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
currentParent := to
|
currentParent := to
|
||||||
@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
|
|||||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||||
// a new message object. The new message will be attached to the same parent as
|
// a new message object. The new message will be attached to the same parent as
|
||||||
// the messageToClone
|
// the messageToClone
|
||||||
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
|
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
|
||||||
newMessage := messageToClone
|
newMessage := messageToClone
|
||||||
newMessage.ID = 0
|
newMessage.ID = 0
|
||||||
newMessage.Replies = nil
|
newMessage.Replies = nil
|
||||||
@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
|
|||||||
return &newMessage, replyCount, nil
|
return &newMessage, replyCount, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
||||||
var messages []model.Message
|
var messages []api.Message
|
||||||
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
||||||
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messageMap := make(map[uint]model.Message)
|
messageMap := make(map[uint]api.Message)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
messageMap[messages[i].ID] = message
|
messageMap[messages[i].ID] = message
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map to store replies by their parent ID
|
// Create a map to store replies by their parent ID
|
||||||
repliesMap := make(map[uint][]model.Message)
|
repliesMap := make(map[uint][]api.Message)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
if messages[i].ParentID != nil {
|
if messages[i].ParentID != nil {
|
||||||
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
||||||
@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
|||||||
// Assign replies, parent, and selected reply to each message
|
// Assign replies, parent, and selected reply to each message
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
if replies, exists := repliesMap[messages[i].ID]; exists {
|
if replies, exists := repliesMap[messages[i].ID]; exists {
|
||||||
messages[i].Replies = make([]model.Message, len(replies))
|
messages[i].Replies = make([]api.Message, len(replies))
|
||||||
for j, m := range replies {
|
for j, m := range replies {
|
||||||
messages[i].Replies[j] = m
|
messages[i].Replies[j] = m
|
||||||
}
|
}
|
||||||
@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
|
|||||||
return messages, nil
|
return messages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
|
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
|
||||||
var messages []model.Message
|
var messages []api.Message
|
||||||
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a map to store messages by their ID
|
// Create a map to store messages by their ID
|
||||||
messageMap := make(map[uint]*model.Message)
|
messageMap := make(map[uint]*api.Message)
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
messageMap[messages[i].ID] = &messages[i]
|
messageMap[messages[i].ID] = &messages[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the path
|
// Build the path
|
||||||
var path []model.Message
|
var path []api.Message
|
||||||
nextID := &message.ID
|
nextID := &message.ID
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
|
|||||||
// PathToRoot traverses the provided message's Parent until reaching the tree
|
// PathToRoot traverses the provided message's Parent until reaching the tree
|
||||||
// root and returns a slice of all messages traversed in chronological order
|
// root and returns a slice of all messages traversed in chronological order
|
||||||
// (starting with the root and ending with the message provided)
|
// (starting with the root and ending with the message provided)
|
||||||
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
|
||||||
if message == nil || message.ID <= 0 {
|
if message == nil || message.ID <= 0 {
|
||||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
path, err := s.buildPath(message, func(m *model.Message) *uint {
|
path, err := s.buildPath(message, func(m *api.Message) *uint {
|
||||||
return m.ParentID
|
return m.ParentID
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
|
|||||||
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
||||||
// tree leaf and returns a slice of all messages traversed in chronological
|
// tree leaf and returns a slice of all messages traversed in chronological
|
||||||
// order (starting with the message provided and ending with the leaf)
|
// order (starting with the message provided and ending with the leaf)
|
||||||
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
|
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
|
||||||
if message == nil || message.ID <= 0 {
|
if message == nil || message.ID <= 0 {
|
||||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.buildPath(message, func(m *model.Message) *uint {
|
return s.buildPath(message, func(m *api.Message) *uint {
|
||||||
return m.SelectedReplyID
|
return m.SelectedReplyID
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
|
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
|
||||||
var latestMessages []model.Message
|
var latestMessages []api.Message
|
||||||
|
|
||||||
subQuery := s.db.Model(&model.Message{}).
|
subQuery := s.db.Model(&api.Message{}).
|
||||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||||
Group("conversation_id")
|
Group("conversation_id")
|
||||||
|
|
||||||
err := s.db.Model(&model.Message{}).
|
err := s.db.Model(&api.Message{}).
|
||||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||||
Group("messages.conversation_id").
|
Group("messages.conversation_id").
|
||||||
Order("created_at DESC").
|
Order("created_at DESC").
|
||||||
|
@ -1,48 +0,0 @@
|
|||||||
package tools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
|
||||||
"dir_tree": DirTreeTool,
|
|
||||||
"read_dir": ReadDirTool,
|
|
||||||
"read_file": ReadFileTool,
|
|
||||||
"write_file": WriteFileTool,
|
|
||||||
"file_insert_lines": FileInsertLinesTool,
|
|
||||||
"file_replace_lines": FileReplaceLinesTool,
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
|
||||||
var toolResults []model.ToolResult
|
|
||||||
for _, toolCall := range toolCalls {
|
|
||||||
var tool *model.Tool
|
|
||||||
for _, available := range toolBag {
|
|
||||||
if available.Name == toolCall.Name {
|
|
||||||
tool = &available
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if tool == nil {
|
|
||||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the tool
|
|
||||||
result, err := tool.Impl(tool, toolCall.Parameters)
|
|
||||||
if err != nil {
|
|
||||||
// This can happen if the model missed or supplied invalid tool args
|
|
||||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.ToolResult{
|
|
||||||
ToolCallID: toolCall.ID,
|
|
||||||
ToolName: toolCall.Name,
|
|
||||||
Result: result,
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults = append(toolResults, toolResult)
|
|
||||||
}
|
|
||||||
return toolResults, nil
|
|
||||||
}
|
|
@ -4,7 +4,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
"github.com/charmbracelet/bubbles/cursor"
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
"github.com/charmbracelet/bubbles/spinner"
|
"github.com/charmbracelet/bubbles/spinner"
|
||||||
@ -16,37 +15,39 @@ import (
|
|||||||
|
|
||||||
// custom tea.Msg types
|
// custom tea.Msg types
|
||||||
type (
|
type (
|
||||||
// sent on each chunk received from LLM
|
|
||||||
msgResponseChunk api.Chunk
|
|
||||||
// sent when response is finished being received
|
|
||||||
msgResponseEnd string
|
|
||||||
// a special case of common.MsgError that stops the response waiting animation
|
|
||||||
msgResponseError error
|
|
||||||
// sent on each completed reply
|
|
||||||
msgResponse models.Message
|
|
||||||
// sent when a conversation is (re)loaded
|
// sent when a conversation is (re)loaded
|
||||||
msgConversationLoaded struct {
|
msgConversationLoaded struct {
|
||||||
conversation *models.Conversation
|
conversation *api.Conversation
|
||||||
rootMessages []models.Message
|
rootMessages []api.Message
|
||||||
}
|
}
|
||||||
// sent when a new conversation title generated
|
// sent when a new conversation title generated
|
||||||
msgConversationTitleGenerated string
|
msgConversationTitleGenerated string
|
||||||
// sent when a conversation's messages are laoded
|
|
||||||
msgMessagesLoaded []models.Message
|
|
||||||
// sent when the conversation has been persisted, triggers a reload of contents
|
// sent when the conversation has been persisted, triggers a reload of contents
|
||||||
msgConversationPersisted struct {
|
msgConversationPersisted struct {
|
||||||
isNew bool
|
isNew bool
|
||||||
conversation *models.Conversation
|
conversation *api.Conversation
|
||||||
messages []models.Message
|
messages []api.Message
|
||||||
}
|
}
|
||||||
|
// sent when a conversation's messages are laoded
|
||||||
|
msgMessagesLoaded []api.Message
|
||||||
|
// a special case of common.MsgError that stops the response waiting animation
|
||||||
|
msgChatResponseError error
|
||||||
|
// sent on each chunk received from LLM
|
||||||
|
msgChatResponseChunk api.Chunk
|
||||||
|
// sent on each completed reply
|
||||||
|
msgChatResponse *api.Message
|
||||||
|
// sent when the response is canceled
|
||||||
|
msgChatResponseCanceled struct{}
|
||||||
|
// sent when results from a tool call are returned
|
||||||
|
msgToolResults []api.ToolResult
|
||||||
// sent when the given message is made the new selected reply of its parent
|
// sent when the given message is made the new selected reply of its parent
|
||||||
msgSelectedReplyCycled *models.Message
|
msgSelectedReplyCycled *api.Message
|
||||||
// sent when the given message is made the new selected root of the current conversation
|
// sent when the given message is made the new selected root of the current conversation
|
||||||
msgSelectedRootCycled *models.Message
|
msgSelectedRootCycled *api.Message
|
||||||
// sent when a message's contents are updated and saved
|
// sent when a message's contents are updated and saved
|
||||||
msgMessageUpdated *models.Message
|
msgMessageUpdated *api.Message
|
||||||
// sent when a message is cloned, with the cloned message
|
// sent when a message is cloned, with the cloned message
|
||||||
msgMessageCloned *models.Message
|
msgMessageCloned *api.Message
|
||||||
)
|
)
|
||||||
|
|
||||||
type focusState int
|
type focusState int
|
||||||
@ -77,14 +78,14 @@ type Model struct {
|
|||||||
|
|
||||||
// app state
|
// app state
|
||||||
state state // current overall status of the view
|
state state // current overall status of the view
|
||||||
conversation *models.Conversation
|
conversation *api.Conversation
|
||||||
rootMessages []models.Message
|
rootMessages []api.Message
|
||||||
messages []models.Message
|
messages []api.Message
|
||||||
selectedMessage int
|
selectedMessage int
|
||||||
editorTarget editorTarget
|
editorTarget editorTarget
|
||||||
stopSignal chan struct{}
|
stopSignal chan struct{}
|
||||||
replyChan chan models.Message
|
replyChan chan api.Message
|
||||||
replyChunkChan chan api.Chunk
|
chatReplyChunks chan api.Chunk
|
||||||
persistence bool // whether we will save new messages in the conversation
|
persistence bool // whether we will save new messages in the conversation
|
||||||
|
|
||||||
// ui state
|
// ui state
|
||||||
@ -111,12 +112,12 @@ func Chat(shared shared.Shared) Model {
|
|||||||
Shared: shared,
|
Shared: shared,
|
||||||
|
|
||||||
state: idle,
|
state: idle,
|
||||||
conversation: &models.Conversation{},
|
conversation: &api.Conversation{},
|
||||||
persistence: true,
|
persistence: true,
|
||||||
|
|
||||||
stopSignal: make(chan struct{}),
|
stopSignal: make(chan struct{}),
|
||||||
replyChan: make(chan models.Message),
|
replyChan: make(chan api.Message),
|
||||||
replyChunkChan: make(chan api.Chunk),
|
chatReplyChunks: make(chan api.Chunk),
|
||||||
|
|
||||||
wrap: true,
|
wrap: true,
|
||||||
selectedMessage: -1,
|
selectedMessage: -1,
|
||||||
@ -144,8 +145,8 @@ func Chat(shared shared.Shared) Model {
|
|||||||
|
|
||||||
system := shared.Ctx.GetSystemPrompt()
|
system := shared.Ctx.GetSystemPrompt()
|
||||||
if system != "" {
|
if system != "" {
|
||||||
m.messages = []models.Message{{
|
m.messages = []api.Message{{
|
||||||
Role: models.MessageRoleSystem,
|
Role: api.MessageRoleSystem,
|
||||||
Content: system,
|
Content: system,
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
@ -166,6 +167,5 @@ func Chat(shared shared.Shared) Model {
|
|||||||
func (m Model) Init() tea.Cmd {
|
func (m Model) Init() tea.Cmd {
|
||||||
return tea.Batch(
|
return tea.Batch(
|
||||||
m.waitForResponseChunk(),
|
m.waitForResponseChunk(),
|
||||||
m.waitForResponse(),
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -2,16 +2,18 @@ package chat
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agent"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *Model) setMessage(i int, msg models.Message) {
|
func (m *Model) setMessage(i int, msg api.Message) {
|
||||||
if i >= len(m.messages) {
|
if i >= len(m.messages) {
|
||||||
panic("i out of range")
|
panic("i out of range")
|
||||||
}
|
}
|
||||||
@ -19,7 +21,7 @@ func (m *Model) setMessage(i int, msg models.Message) {
|
|||||||
m.messageCache[i] = m.renderMessage(i)
|
m.messageCache[i] = m.renderMessage(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) addMessage(msg models.Message) {
|
func (m *Model) addMessage(msg api.Message) {
|
||||||
m.messages = append(m.messages, msg)
|
m.messages = append(m.messages, msg)
|
||||||
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
|
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
|
||||||
}
|
}
|
||||||
@ -88,7 +90,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
|
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
err := m.Shared.Ctx.Store.UpdateConversation(conversation)
|
err := m.Shared.Ctx.Store.UpdateConversation(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -101,7 +103,7 @@ func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.C
|
|||||||
// Clones the given message (and its descendents). If selected is true, updates
|
// Clones the given message (and its descendents). If selected is true, updates
|
||||||
// either its parent's SelectedReply or its conversation's SelectedRoot to
|
// either its parent's SelectedReply or its conversation's SelectedRoot to
|
||||||
// point to the new clone
|
// point to the new clone
|
||||||
func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
|
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
msg, _, err := m.Ctx.Store.CloneBranch(message)
|
msg, _, err := m.Ctx.Store.CloneBranch(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -123,7 +125,7 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
|
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
err := m.Shared.Ctx.Store.UpdateMessage(message)
|
err := m.Shared.Ctx.Store.UpdateMessage(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -133,7 +135,7 @@ func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) {
|
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
||||||
currentIndex := -1
|
currentIndex := -1
|
||||||
for i, reply := range choices {
|
for i, reply := range choices {
|
||||||
if reply.ID == selected.ID {
|
if reply.ID == selected.ID {
|
||||||
@ -158,7 +160,7 @@ func cycleSelectedMessage(selected *models.Message, choices []models.Message, di
|
|||||||
return &choices[next], nil
|
return &choices[next], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd {
|
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
|
||||||
if len(m.rootMessages) < 2 {
|
if len(m.rootMessages) < 2 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -178,7 +180,7 @@ func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDir
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd {
|
func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
|
||||||
if len(message.Replies) < 2 {
|
if len(message.Replies) < 2 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -218,15 +220,12 @@ func (m *Model) persistConversation() tea.Cmd {
|
|||||||
// else, we'll handle updating an existing conversation's messages
|
// else, we'll handle updating an existing conversation's messages
|
||||||
for i := range messages {
|
for i := range messages {
|
||||||
if messages[i].ID > 0 {
|
if messages[i].ID > 0 {
|
||||||
// message has an ID, update its contents
|
// message has an ID, update it
|
||||||
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
|
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return shared.MsgError(err)
|
return shared.MsgError(err)
|
||||||
}
|
}
|
||||||
} else if i > 0 {
|
} else if i > 0 {
|
||||||
if messages[i].Content == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// messages is new, so add it as a reply to previous message
|
// messages is new, so add it as a reply to previous message
|
||||||
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
|
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -243,13 +242,23 @@ func (m *Model) persistConversation() tea.Cmd {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
results, err := agent.ExecuteToolCalls(toolCalls, m.Ctx.EnabledTools)
|
||||||
|
if err != nil {
|
||||||
|
return shared.MsgError(err)
|
||||||
|
}
|
||||||
|
return msgToolResults(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Model) promptLLM() tea.Cmd {
|
func (m *Model) promptLLM() tea.Cmd {
|
||||||
m.state = pendingResponse
|
m.state = pendingResponse
|
||||||
m.replyCursor.Blink = false
|
m.replyCursor.Blink = false
|
||||||
|
|
||||||
m.tokenCount = 0
|
|
||||||
m.startTime = time.Now()
|
m.startTime = time.Now()
|
||||||
m.elapsed = 0
|
m.elapsed = 0
|
||||||
|
m.tokenCount = 0
|
||||||
|
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
|
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
|
||||||
@ -257,36 +266,34 @@ func (m *Model) promptLLM() tea.Cmd {
|
|||||||
return shared.MsgError(err)
|
return shared.MsgError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := models.RequestParameters{
|
requestParams := api.RequestParameters{
|
||||||
Model: model,
|
Model: model,
|
||||||
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
|
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
|
||||||
ToolBag: m.Shared.Ctx.EnabledTools,
|
ToolBag: m.Shared.Ctx.EnabledTools,
|
||||||
}
|
}
|
||||||
|
|
||||||
replyHandler := func(msg models.Message) {
|
|
||||||
m.replyChan <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
canceled := false
|
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-m.stopSignal:
|
case <-m.stopSignal:
|
||||||
canceled = true
|
|
||||||
cancel()
|
cancel()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
resp, err := provider.CreateChatCompletionStream(
|
resp, err := provider.CreateChatCompletionStream(
|
||||||
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
|
ctx, requestParams, m.messages, m.chatReplyChunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil && !canceled {
|
if errors.Is(err, context.Canceled) {
|
||||||
return msgResponseError(err)
|
return msgChatResponseCanceled(struct{}{})
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgResponseEnd(resp)
|
if err != nil {
|
||||||
|
return msgChatResponseError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgChatResponse(resp)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
@ -150,12 +150,12 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
|
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser {
|
||||||
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addMessage(models.Message{
|
m.addMessage(api.Message{
|
||||||
Role: models.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
"github.com/charmbracelet/bubbles/cursor"
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
@ -21,15 +21,9 @@ func (m *Model) HandleResize(width, height int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) waitForResponse() tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
return msgResponse(<-m.replyChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) waitForResponseChunk() tea.Cmd {
|
func (m *Model) waitForResponseChunk() tea.Cmd {
|
||||||
return func() tea.Msg {
|
return func() tea.Msg {
|
||||||
return msgResponseChunk(<-m.replyChunkChan)
|
return msgChatResponseChunk(<-m.chatReplyChunks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +42,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
|
|
||||||
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
|
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
|
||||||
// clear existing messages if we're loading a new conversation
|
// clear existing messages if we're loading a new conversation
|
||||||
m.messages = []models.Message{}
|
m.messages = []api.Message{}
|
||||||
m.selectedMessage = 0
|
m.selectedMessage = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -87,7 +81,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
}
|
}
|
||||||
m.rebuildMessageCache()
|
m.rebuildMessageCache()
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
case msgResponseChunk:
|
case msgChatResponseChunk:
|
||||||
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
|
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
|
||||||
|
|
||||||
if msg.Content == "" {
|
if msg.Content == "" {
|
||||||
@ -100,8 +94,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
m.setMessageContents(last, m.messages[last].Content+msg.Content)
|
m.setMessageContents(last, m.messages[last].Content+msg.Content)
|
||||||
} else {
|
} else {
|
||||||
// use chunk in new message
|
// use chunk in new message
|
||||||
m.addMessage(models.Message{
|
m.addMessage(api.Message{
|
||||||
Role: models.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
Content: msg.Content,
|
Content: msg.Content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -113,10 +107,10 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
|
|
||||||
m.tokenCount += msg.TokenCount
|
m.tokenCount += msg.TokenCount
|
||||||
m.elapsed = time.Now().Sub(m.startTime)
|
m.elapsed = time.Now().Sub(m.startTime)
|
||||||
case msgResponse:
|
case msgChatResponse:
|
||||||
cmds = append(cmds, m.waitForResponse()) // wait for the next response
|
m.state = idle
|
||||||
|
|
||||||
reply := models.Message(msg)
|
reply := (*api.Message)(msg)
|
||||||
reply.Content = strings.TrimSpace(reply.Content)
|
reply.Content = strings.TrimSpace(reply.Content)
|
||||||
|
|
||||||
last := len(m.messages) - 1
|
last := len(m.messages) - 1
|
||||||
@ -124,11 +118,18 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
panic("Unexpected empty messages handling msgAssistantReply")
|
panic("Unexpected empty messages handling msgAssistantReply")
|
||||||
}
|
}
|
||||||
|
|
||||||
if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
|
if m.messages[last].Role.IsAssistant() {
|
||||||
// this was a continuation, so replace the previous message with the completed reply
|
// TODO: handle continuations gracefully - some models support them well, others fail horribly.
|
||||||
m.setMessage(last, reply)
|
m.setMessage(last, *reply)
|
||||||
} else {
|
} else {
|
||||||
m.addMessage(reply)
|
m.addMessage(*reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch reply.Role {
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
// TODO: user confirmation before execution
|
||||||
|
// m.state = waitingForConfirmation
|
||||||
|
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.persistence {
|
if m.persistence {
|
||||||
@ -140,17 +141,32 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
case msgResponseEnd:
|
case msgChatResponseCanceled:
|
||||||
m.state = idle
|
m.state = idle
|
||||||
last := len(m.messages) - 1
|
|
||||||
if last < 0 {
|
|
||||||
panic("Unexpected empty messages handling msgResponseEnd")
|
|
||||||
}
|
|
||||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
case msgResponseError:
|
case msgChatResponseError:
|
||||||
m.state = idle
|
m.state = idle
|
||||||
m.Shared.Err = error(msg)
|
m.Shared.Err = error(msg)
|
||||||
|
m.updateContent()
|
||||||
|
case msgToolResults:
|
||||||
|
last := len(m.messages) - 1
|
||||||
|
if last < 0 {
|
||||||
|
panic("Unexpected empty messages handling msgAssistantReply")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.messages[last].Role != api.MessageRoleToolCall {
|
||||||
|
panic("Previous message not a tool call, unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addMessage(api.Message{
|
||||||
|
Role: api.MessageRoleToolResult,
|
||||||
|
ToolResults: api.ToolResults(msg),
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.persistence {
|
||||||
|
cmds = append(cmds, m.persistConversation())
|
||||||
|
}
|
||||||
|
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
case msgConversationTitleGenerated:
|
case msgConversationTitleGenerated:
|
||||||
title := string(msg)
|
title := string(msg)
|
||||||
@ -167,7 +183,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
|||||||
m.conversation = msg.conversation
|
m.conversation = msg.conversation
|
||||||
m.messages = msg.messages
|
m.messages = msg.messages
|
||||||
if msg.isNew {
|
if msg.isNew {
|
||||||
m.rootMessages = []models.Message{m.messages[0]}
|
m.rootMessages = []api.Message{m.messages[0]}
|
||||||
}
|
}
|
||||||
m.rebuildMessageCache()
|
m.rebuildMessageCache()
|
||||||
m.updateContent()
|
m.updateContent()
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
@ -63,22 +63,22 @@ func (m Model) View() string {
|
|||||||
return lipgloss.JoinVertical(lipgloss.Left, sections...)
|
return lipgloss.JoinVertical(lipgloss.Left, sections...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) renderMessageHeading(i int, message *models.Message) string {
|
func (m *Model) renderMessageHeading(i int, message *api.Message) string {
|
||||||
icon := ""
|
icon := ""
|
||||||
friendly := message.Role.FriendlyRole()
|
friendly := message.Role.FriendlyRole()
|
||||||
style := lipgloss.NewStyle().Faint(true).Bold(true)
|
style := lipgloss.NewStyle().Faint(true).Bold(true)
|
||||||
|
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case models.MessageRoleSystem:
|
case api.MessageRoleSystem:
|
||||||
icon = "⚙️"
|
icon = "⚙️"
|
||||||
case models.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
style = userStyle
|
style = userStyle
|
||||||
case models.MessageRoleAssistant:
|
case api.MessageRoleAssistant:
|
||||||
style = assistantStyle
|
style = assistantStyle
|
||||||
case models.MessageRoleToolCall:
|
case api.MessageRoleToolCall:
|
||||||
style = assistantStyle
|
style = assistantStyle
|
||||||
friendly = models.MessageRoleAssistant.FriendlyRole()
|
friendly = api.MessageRoleAssistant.FriendlyRole()
|
||||||
case models.MessageRoleToolResult:
|
case api.MessageRoleToolResult:
|
||||||
icon = "🔧"
|
icon = "🔧"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,21 +139,21 @@ func (m *Model) renderMessage(i int) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Show the assistant's cursor
|
// Show the assistant's cursor
|
||||||
if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
|
if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == api.MessageRoleAssistant {
|
||||||
sb.WriteString(m.replyCursor.View())
|
sb.WriteString(m.replyCursor.View())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write tool call info
|
// Write tool call info
|
||||||
var toolString string
|
var toolString string
|
||||||
switch msg.Role {
|
switch msg.Role {
|
||||||
case models.MessageRoleToolCall:
|
case api.MessageRoleToolCall:
|
||||||
bytes, err := yaml.Marshal(msg.ToolCalls)
|
bytes, err := yaml.Marshal(msg.ToolCalls)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
toolString = "Could not serialize ToolCalls"
|
toolString = "Could not serialize ToolCalls"
|
||||||
} else {
|
} else {
|
||||||
toolString = "tool_calls:\n" + string(bytes)
|
toolString = "tool_calls:\n" + string(bytes)
|
||||||
}
|
}
|
||||||
case models.MessageRoleToolResult:
|
case api.MessageRoleToolResult:
|
||||||
if !m.showToolResults {
|
if !m.showToolResults {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -221,11 +221,11 @@ func (m *Model) conversationMessagesView() string {
|
|||||||
m.messageOffsets[i] = lineCnt
|
m.messageOffsets[i] = lineCnt
|
||||||
|
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case models.MessageRoleToolCall:
|
case api.MessageRoleToolCall:
|
||||||
if !m.showToolResults && message.Content == "" {
|
if !m.showToolResults && message.Content == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
case models.MessageRoleToolResult:
|
case api.MessageRoleToolResult:
|
||||||
if !m.showToolResults {
|
if !m.showToolResults {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -251,9 +251,9 @@ func (m *Model) conversationMessagesView() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Render a placeholder for the incoming assistant reply
|
// Render a placeholder for the incoming assistant reply
|
||||||
if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
|
if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant) {
|
||||||
heading := m.renderMessageHeading(-1, &models.Message{
|
heading := m.renderMessageHeading(-1, &api.Message{
|
||||||
Role: models.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
})
|
})
|
||||||
sb.WriteString(heading)
|
sb.WriteString(heading)
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
|
@ -5,7 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
@ -16,15 +16,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type loadedConversation struct {
|
type loadedConversation struct {
|
||||||
conv models.Conversation
|
conv api.Conversation
|
||||||
lastReply models.Message
|
lastReply api.Message
|
||||||
}
|
}
|
||||||
|
|
||||||
type (
|
type (
|
||||||
// sent when conversation list is loaded
|
// sent when conversation list is loaded
|
||||||
msgConversationsLoaded ([]loadedConversation)
|
msgConversationsLoaded ([]loadedConversation)
|
||||||
// sent when a conversation is selected
|
// sent when a conversation is selected
|
||||||
msgConversationSelected models.Conversation
|
msgConversationSelected api.Conversation
|
||||||
)
|
)
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
|
Loading…
Reference in New Issue
Block a user