Package restructure and API changes, several fixes

- More emphasis on `api` package. It now holds database model structs
  from `lmcli/models` (which is now gone) as well as the tool spec,
  call, and result types. `tools.Tool` is now `api.ToolSpec`.
  `api.ChatCompletionClient` was renamed to
  `api.ChatCompletionProvider`.

- Change ChatCompletion interface and implementations to no longer do
  automatic tool call recursion - they simply return a ToolCall message
  which the caller can decide what to do with (e.g. prompt for user
  confirmation before executing)

- `api.ChatCompletionProvider` functions have had their ReplyCallback
  parameter removed, as now they only return a single reply.

- Added a top-level `agent` package, moved the current built-in tools
  implementations under `agent/toolbox`. `tools.ExecuteToolCalls` is now
  `agent.ExecuteToolCalls`.

- Fixed request context handling in openai, google, ollama (use
  `NewRequestWithContext`), cleaned up request cancellation in TUI

- Fix tool call tui persistence bug (we were skipping message with empty
  content)

- Now handle tool calling from TUI layer

TODO:
- Prompt users before executing tool calls
- Automatically send tool results to the model (or make this toggleable)
This commit is contained in:
Matt Low 2024-06-12 08:35:07 +00:00
parent 85a2abbbf3
commit 3fde58b77d
35 changed files with 608 additions and 749 deletions

View File

@ -1,4 +1,4 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
@ -7,8 +7,8 @@ import (
"strconv" "strconv"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents. const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
@ -27,10 +27,10 @@ Example result:
} }
` `
var DirTreeTool = model.Tool{ var DirTreeTool = api.ToolSpec{
Name: "dir_tree", Name: "dir_tree",
Description: TREE_DESCRIPTION, Description: TREE_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "relative_path", Name: "relative_path",
Type: "string", Type: "string",
@ -42,7 +42,7 @@ var DirTreeTool = model.Tool{
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.", Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
if tmp, ok := args["relative_path"]; ok { if tmp, ok := args["relative_path"]; ok {
relativeDir, ok = tmp.(string) relativeDir, ok = tmp.(string)
@ -76,25 +76,25 @@ var DirTreeTool = model.Tool{
}, },
} }
func tree(path string, depth int) model.CallResult { func tree(path string, depth int) api.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
var treeOutput strings.Builder var treeOutput strings.Builder
treeOutput.WriteString(path + "\n") treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth) err := buildTree(&treeOutput, path, "", depth)
if err != nil { if err != nil {
return model.CallResult{ return api.CallResult{
Message: err.Error(), Message: err.Error(),
} }
} }
return model.CallResult{Result: treeOutput.String()} return api.CallResult{Result: treeOutput.String()}
} }
func buildTree(output *strings.Builder, path string, prefix string, depth int) error { func buildTree(output *strings.Builder, path string, prefix string, depth int) error {

View File

@ -1,22 +1,22 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.` Make sure your inserts match the flow and indentation of surrounding content.`
var FileInsertLinesTool = model.Tool{ var FileInsertLinesTool = api.ToolSpec{
Name: "file_insert_lines", Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION, Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -36,7 +36,7 @@ var FileInsertLinesTool = model.Tool{
Required: true, Required: true,
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.") return "", fmt.Errorf("path parameter to write_file was not included.")
@ -72,27 +72,27 @@ var FileInsertLinesTool = model.Tool{
}, },
} }
func fileInsertLines(path string, position int, content string) model.CallResult { func fileInsertLines(path string, position int, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
// Read the existing file's content // Read the existing file's content
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
_, err = os.Create(path) _, err = os.Create(path)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
} }
data = []byte{} data = []byte{}
} }
if position < 1 { if position < 1 {
return model.CallResult{Message: "start_line cannot be less than 1"} return api.CallResult{Message: "start_line cannot be less than 1"}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) model.CallResult
// Join the lines and write back to the file // Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644) err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return model.CallResult{Result: newContent} return api.CallResult{Result: newContent}
} }

View File

@ -1,12 +1,12 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
@ -15,10 +15,10 @@ Useful for re-writing snippets/blocks of code or entire functions.
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.` Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
var FileReplaceLinesTool = model.Tool{ var FileReplaceLinesTool = api.ToolSpec{
Name: "file_replace_lines", Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION, Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -42,7 +42,7 @@ var FileReplaceLinesTool = model.Tool{
Description: `Content to replace specified range. Omit to remove the specified range.`, Description: `Content to replace specified range. Omit to remove the specified range.`,
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.") return "", fmt.Errorf("path parameter to write_file was not included.")
@ -87,27 +87,27 @@ var FileReplaceLinesTool = model.Tool{
}, },
} }
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult { func fileReplaceLines(path string, startLine int, endLine int, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
// Read the existing file's content // Read the existing file's content
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
_, err = os.Create(path) _, err = os.Create(path)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
} }
data = []byte{} data = []byte{}
} }
if startLine < 1 { if startLine < 1 {
return model.CallResult{Message: "start_line cannot be less than 1"} return api.CallResult{Message: "start_line cannot be less than 1"}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -126,8 +126,8 @@ func fileReplaceLines(path string, startLine int, endLine int, content string) m
// Join the lines and write back to the file // Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644) err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return model.CallResult{Result: newContent} return api.CallResult{Result: newContent}
} }

View File

@ -1,4 +1,4 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
@ -6,8 +6,8 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
@ -25,17 +25,17 @@ Example result:
For files, size represents the size of the file, in bytes. For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.` For directories, size represents the number of entries in that directory.`
var ReadDirTool = model.Tool{ var ReadDirTool = api.ToolSpec{
Name: "read_dir", Name: "read_dir",
Description: READ_DIR_DESCRIPTION, Description: READ_DIR_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "relative_dir", Name: "relative_dir",
Type: "string", Type: "string",
Description: "If set, read the contents of a directory relative to the current one.", Description: "If set, read the contents of a directory relative to the current one.",
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
tmp, ok := args["relative_dir"] tmp, ok := args["relative_dir"]
if ok { if ok {
@ -53,18 +53,18 @@ var ReadDirTool = model.Tool{
}, },
} }
func readDir(path string) model.CallResult { func readDir(path string) api.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
files, err := os.ReadDir(path) files, err := os.ReadDir(path)
if err != nil { if err != nil {
return model.CallResult{ return api.CallResult{
Message: err.Error(), Message: err.Error(),
} }
} }
@ -96,5 +96,5 @@ func readDir(path string) model.CallResult {
}) })
} }
return model.CallResult{Result: dirContents} return api.CallResult{Result: dirContents}
} }

View File

@ -1,12 +1,12 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory. const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
@ -21,10 +21,10 @@ Example result:
"result": "1\tthe contents\n2\tof the file\n" "result": "1\tthe contents\n2\tof the file\n"
}` }`
var ReadFileTool = model.Tool{ var ReadFileTool = api.ToolSpec{
Name: "read_file", Name: "read_file",
Description: READ_FILE_DESCRIPTION, Description: READ_FILE_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -33,7 +33,7 @@ var ReadFileTool = model.Tool{
}, },
}, },
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) { Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.") return "", fmt.Errorf("Path parameter to read_file was not included.")
@ -51,14 +51,14 @@ var ReadFileTool = model.Tool{
}, },
} }
func readFile(path string) model.CallResult { func readFile(path string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -67,7 +67,7 @@ func readFile(path string) model.CallResult {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
} }
return model.CallResult{ return api.CallResult{
Result: content.String(), Result: content.String(),
} }
} }

View File

@ -1,11 +1,11 @@
package tools package toolbox
import ( import (
"fmt" "fmt"
"os" "os"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
@ -15,10 +15,10 @@ Example result:
"message": "success" "message": "success"
}` }`
var WriteFileTool = model.Tool{ var WriteFileTool = api.ToolSpec{
Name: "write_file", Name: "write_file",
Description: WRITE_FILE_DESCRIPTION, Description: WRITE_FILE_DESCRIPTION,
Parameters: []model.ToolParameter{ Parameters: []api.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -32,7 +32,7 @@ var WriteFileTool = model.Tool{
Required: true, Required: true,
}, },
}, },
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) { Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.") return "", fmt.Errorf("Path parameter to write_file was not included.")
@ -58,14 +58,14 @@ var WriteFileTool = model.Tool{
}, },
} }
func writeFile(path string, content string) model.CallResult { func writeFile(path string, content string) api.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return model.CallResult{Message: reason} return api.CallResult{Message: reason}
} }
err := os.WriteFile(path, []byte(content), 0644) err := os.WriteFile(path, []byte(content), 0644)
if err != nil { if err != nil {
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return model.CallResult{} return api.CallResult{}
} }

48
pkg/agent/tools.go Normal file
View File

@ -0,0 +1,48 @@
package agent
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agent/toolbox"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
"dir_tree": toolbox.DirTreeTool,
"read_dir": toolbox.ReadDirTool,
"read_file": toolbox.ReadFileTool,
"write_file": toolbox.WriteFileTool,
"file_insert_lines": toolbox.FileInsertLinesTool,
"file_replace_lines": toolbox.FileReplaceLinesTool,
}
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
var toolResults []api.ToolResult
for _, call := range calls {
var tool *api.ToolSpec
for i := range available {
if available[i].Name == call.Name {
tool = &available[i]
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
}
// Execute the tool
result, err := tool.Impl(tool, call.Parameters)
if err != nil {
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
}
toolResult := api.ToolResult{
ToolCallID: call.ID,
ToolName: call.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -2,35 +2,41 @@ package api
import ( import (
"context" "context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type ReplyCallback func(model.Message) type ReplyCallback func(Message)
type Chunk struct { type Chunk struct {
Content string Content string
TokenCount uint TokenCount uint
} }
type ChatCompletionClient interface { type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
ToolBag []ToolSpec
}
type ChatCompletionProvider interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string. // complete user-facing response is returned as a string.
CreateChatCompletion( CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params RequestParameters,
messages []model.Message, messages []Message,
callback ReplyCallback, ) (*Message, error)
) (string, error)
// Like CreateChageCompletion, except the response is streamed via // Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received. // the output channel as it's received.
CreateChatCompletionStream( CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params RequestParameters,
messages []model.Message, messages []Message,
callback ReplyCallback, chunks chan<- Chunk,
output chan<- Chunk, ) (*Message, error)
) (string, error)
} }

11
pkg/api/conversation.go Normal file
View File

@ -0,0 +1,11 @@
package api
import "database/sql"
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}

View File

@ -1,7 +1,6 @@
package model package api
import ( import (
"database/sql"
"time" "time"
) )
@ -32,24 +31,6 @@ type Message struct {
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
ToolBag []Tool
}
func (m *MessageRole) IsAssistant() bool { func (m *MessageRole) IsAssistant() bool {
switch *m { switch *m {
case MessageRoleAssistant, MessageRoleToolCall: case MessageRoleAssistant, MessageRoleToolCall:

View File

@ -11,11 +11,9 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func buildRequest(params model.RequestParameters, messages []model.Message) Request { func buildRequest(params api.RequestParameters, messages []api.Message) Request {
requestBody := Request{ requestBody := Request{
Model: params.Model, Model: params.Model,
Messages: make([]Message, len(messages)), Messages: make([]Message, len(messages)),
@ -30,7 +28,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
} }
startIdx := 0 startIdx := 0
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
requestBody.System = messages[0].Content requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:] requestBody.Messages = requestBody.Messages[1:]
startIdx = 1 startIdx = 1
@ -48,7 +46,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
message := &requestBody.Messages[i] message := &requestBody.Messages[i]
switch msg.Role { switch msg.Role {
case model.MessageRoleToolCall: case api.MessageRoleToolCall:
message.Role = "assistant" message.Role = "assistant"
if msg.Content != "" { if msg.Content != "" {
message.Content = msg.Content message.Content = msg.Content
@ -63,7 +61,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
} else { } else {
message.Content = xmlString message.Content = xmlString
} }
case model.MessageRoleToolResult: case api.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString() xmlString, err := xmlFuncResults.XMLString()
if err != nil { if err != nil {
@ -105,26 +103,25 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback, ) (*api.Message, error) {
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
request := buildRequest(params, messages) request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
var response Response var response Response
err = json.NewDecoder(resp.Body).Decode(&response) err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode response: %v", err) return nil, fmt.Errorf("failed to decode response: %v", err)
} }
sb := strings.Builder{} sb := strings.Builder{}
@ -137,34 +134,28 @@ func (c *AnthropicClient) CreateChatCompletion(
} }
for _, content := range response.Content { for _, content := range response.Content {
var reply model.Message
switch content.Type { switch content.Type {
case "text": case "text":
reply = model.Message{ sb.WriteString(content.Text)
Role: model.MessageRoleAssistant,
Content: content.Text,
}
sb.WriteString(reply.Content)
default: default:
return "", fmt.Errorf("unsupported message type: %s", content.Type) return nil, fmt.Errorf("unsupported message type: %s", content.Type)
}
if callback != nil {
callback(reply)
} }
} }
return sb.String(), nil return &api.Message{
Role: api.MessageRoleAssistant,
Content: sb.String(),
}, nil
} }
func (c *AnthropicClient) CreateChatCompletionStream( func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (string, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -172,19 +163,18 @@ func (c *AnthropicClient) CreateChatCompletionStream(
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
sb := strings.Builder{} sb := strings.Builder{}
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() { if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll // this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result // include its contents in the final result
// TODO: handle this at higher level
sb.WriteString(lastMessage.Content) sb.WriteString(lastMessage.Content)
continuation = true
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
@ -200,29 +190,29 @@ func (c *AnthropicClient) CreateChatCompletionStream(
var event map[string]interface{} var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event) err := json.Unmarshal([]byte(line), &event)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
} }
eventType, ok := event["type"].(string) eventType, ok := event["type"].(string)
if !ok { if !ok {
return "", fmt.Errorf("invalid event: %s", line) return nil, fmt.Errorf("invalid event: %s", line)
} }
switch eventType { switch eventType {
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return nil, fmt.Errorf("an error occurred: %s", event["error"])
default: default:
return sb.String(), fmt.Errorf("unknown event type: %s", eventType) return nil, fmt.Errorf("unknown event type: %s", eventType)
} }
} else if strings.HasPrefix(line, "data:") { } else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{} var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event) err := json.Unmarshal([]byte(data), &event)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal event data: %v", err) return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
} }
eventType, ok := event["type"].(string) eventType, ok := event["type"].(string)
if !ok { if !ok {
return "", fmt.Errorf("invalid event type") return nil, fmt.Errorf("invalid event type")
} }
switch eventType { switch eventType {
@ -235,11 +225,11 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "content_block_delta": case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{}) delta, ok := event["delta"].(map[string]interface{})
if !ok { if !ok {
return "", fmt.Errorf("invalid content block delta") return nil, fmt.Errorf("invalid content block delta")
} }
text, ok := delta["text"].(string) text, ok := delta["text"].(string)
if !ok { if !ok {
return "", fmt.Errorf("invalid text delta") return nil, fmt.Errorf("invalid text delta")
} }
sb.WriteString(text) sb.WriteString(text)
output <- api.Chunk{ output <- api.Chunk{
@ -251,7 +241,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_delta": case "message_delta":
delta, ok := event["delta"].(map[string]interface{}) delta, ok := event["delta"].(map[string]interface{})
if !ok { if !ok {
return "", fmt.Errorf("invalid message delta") return nil, fmt.Errorf("invalid message delta")
} }
stopReason, ok := delta["stop_reason"].(string) stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" { if ok && stopReason == "stop_sequence" {
@ -261,7 +251,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
start := strings.Index(content, "<function_calls>") start := strings.Index(content, "<function_calls>")
if start == -1 { if start == -1 {
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found") return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
} }
sb.WriteString(FUNCTION_STOP_SEQUENCE) sb.WriteString(FUNCTION_STOP_SEQUENCE)
@ -269,59 +259,31 @@ func (c *AnthropicClient) CreateChatCompletionStream(
Content: FUNCTION_STOP_SEQUENCE, Content: FUNCTION_STOP_SEQUENCE,
TokenCount: 1, TokenCount: 1,
} }
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls) err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err) return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
} }
toolCall := model.Message{ return &api.Message{
Role: model.MessageRoleToolCall, Role: api.MessageRoleToolCall,
// function call xml stripped from content for model interop // function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]), Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
} }, nil
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return "", err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
content := sb.String() content := sb.String()
if callback != nil { return &api.Message{
callback(model.Message{ Role: api.MessageRoleAssistant,
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}) }, nil
}
return content, nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return nil, fmt.Errorf("an error occurred: %s", event["error"])
default: default:
fmt.Printf("\nUnrecognized event: %s\n", data) fmt.Printf("\nUnrecognized event: %s\n", data)
} }
@ -329,8 +291,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return "", fmt.Errorf("failed to read response body: %v", err) return nil, fmt.Errorf("failed to read response body: %v", err)
} }
return "", fmt.Errorf("unexpected end of stream") return nil, fmt.Errorf("unexpected end of stream")
} }

View File

@ -6,7 +6,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
) )
const FUNCTION_STOP_SEQUENCE = "</function_calls>" const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} {
return ret return ret
} }
func convertToolsToXMLTools(tools []model.Tool) XMLTools { func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
converted := make([]XMLToolDescription, len(tools)) converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools { for i, tool := range tools {
converted[i].ToolName = tool.Name converted[i].ToolName = tool.Name
@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []model.Tool) XMLTools {
} }
} }
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall { func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke)) toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke { for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String) toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.
return toolCalls return toolCalls
} }
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls { func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls)) converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls { for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters var params XMLFunctionInvokeParameters
@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionC
} }
} }
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults { func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults)) converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults { for i, result := range toolResults {
converted[i].ToolName = result.ToolName converted[i].ToolName = result.ToolName
@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFu
} }
} }
func buildToolsSystemPrompt(tools []model.Tool) string { func buildToolsSystemPrompt(tools []api.ToolSpec) string {
xmlTools := convertToolsToXMLTools(tools) xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString() xmlToolsString, err := xmlTools.XMLString()
if err != nil { if err != nil {
panic("Could not serialize []model.Tool to XMLTools") panic("Could not serialize []api.Tool to XMLTools")
} }
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
} }

View File

@ -11,11 +11,9 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func convertTools(tools []model.Tool) []Tool { func convertTools(tools []api.ToolSpec) []Tool {
geminiTools := make([]Tool, len(tools)) geminiTools := make([]Tool, len(tools))
for i, tool := range tools { for i, tool := range tools {
params := make(map[string]ToolParameter) params := make(map[string]ToolParameter)
@ -50,7 +48,7 @@ func convertTools(tools []model.Tool) []Tool {
return geminiTools return geminiTools
} }
func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart { func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls)) converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
args := make(map[string]string) args := make(map[string]string)
@ -65,8 +63,8 @@ func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
return converted return converted
} }
func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall { func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
converted := make([]model.ToolCall, len(functionCalls)) converted := make([]api.ToolCall, len(functionCalls))
for i, call := range functionCalls { for i, call := range functionCalls {
params := make(map[string]interface{}) params := make(map[string]interface{})
for k, v := range call.Args { for k, v := range call.Args {
@ -78,7 +76,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
return converted return converted
} }
func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) { func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults)) results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults { for i, result := range toolResults {
var obj interface{} var obj interface{}
@ -95,14 +93,14 @@ func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionRespo
} }
func createGenerateContentRequest( func createGenerateContentRequest(
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
) (*GenerateContentRequest, error) { ) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages)) requestContents := make([]Content, 0, len(messages))
startIdx := 0 startIdx := 0
var system string var system string
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
system = messages[0].Content system = messages[0].Content
startIdx = 1 startIdx = 1
} }
@ -135,9 +133,9 @@ func createGenerateContentRequest(
default: default:
var role string var role string
switch m.Role { switch m.Role {
case model.MessageRoleAssistant: case api.MessageRoleAssistant:
role = "model" role = "model"
case model.MessageRoleUser: case api.MessageRoleUser:
role = "user" role = "user"
} }
@ -183,55 +181,14 @@ func createGenerateContentRequest(
return request, nil return request, nil
} }
func handleToolCalls( func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
params model.RequestParameters,
content string,
toolCalls []model.ToolCall,
callback api.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx)) resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body) bytes, _ := io.ReadAll(resp.Body)
@ -243,42 +200,41 @@ func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Resp
func (c *Client) CreateChatCompletion( func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback, ) (*api.Message, error) {
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req, err := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil { if err != nil {
return "", err return nil, err
} }
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
url := fmt.Sprintf( url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s", "%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey, c.BaseURL, params.Model, c.APIKey,
) )
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp GenerateContentResponse var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return "", err return nil, err
} }
choice := completionResp.Candidates[0] choice := completionResp.Candidates[0]
@ -301,58 +257,50 @@ func (c *Client) CreateChatCompletion(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls( return &api.Message{
params, content, convertToolCallToAPI(toolCalls), callback, messages, Role: api.MessageRoleToolCall,
)
if err != nil {
return content, err
}
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}) ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return content, nil return &api.Message{
Role: api.MessageRoleAssistant,
Content: content,
}, nil
} }
func (c *Client) CreateChatCompletionStream( func (c *Client) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (string, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req, err := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil { if err != nil {
return "", err return nil, err
} }
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
url := fmt.Sprintf( url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey, c.BaseURL, params.Model, c.APIKey,
) )
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -374,7 +322,7 @@ func (c *Client) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return "", err return nil, err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -387,7 +335,7 @@ func (c *Client) CreateChatCompletionStream(
var resp GenerateContentResponse var resp GenerateContentResponse
err = json.Unmarshal(line, &resp) err = json.Unmarshal(line, &resp)
if err != nil { if err != nil {
return "", err return nil, err
} }
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
@ -409,21 +357,15 @@ func (c *Client) CreateChatCompletionStream(
// If there are function calls, handle them and recurse // If there are function calls, handle them and recurse
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls( return &api.Message{
params, content.String(), convertToolCallToAPI(toolCalls), callback, messages, Role: api.MessageRoleToolCall,
)
if err != nil {
return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return content.String(), nil return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
} }

View File

@ -11,7 +11,6 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type OllamaClient struct { type OllamaClient struct {
@ -43,8 +42,8 @@ type OllamaResponse struct {
} }
func createOllamaRequest( func createOllamaRequest(
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
) OllamaRequest { ) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages)) requestMessages := make([]OllamaMessage, 0, len(messages))
@ -64,11 +63,11 @@ func createOllamaRequest(
return request return request
} }
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) { func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx)) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -83,12 +82,11 @@ func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*htt
func (c *OllamaClient) CreateChatCompletion( func (c *OllamaClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback, ) (*api.Message, error) {
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req := createOllamaRequest(params, messages) req := createOllamaRequest(params, messages)
@ -96,46 +94,40 @@ func (c *OllamaClient) CreateChatCompletion(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp OllamaResponse var completionResp OllamaResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return "", err return nil, err
} }
content := completionResp.Message.Content return &api.Message{
if callback != nil { Role: api.MessageRoleAssistant,
callback(model.Message{ Content: completionResp.Message.Content,
Role: model.MessageRoleAssistant, }, nil
Content: content,
})
}
return content, nil
} }
func (c *OllamaClient) CreateChatCompletionStream( func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (string, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req := createOllamaRequest(params, messages) req := createOllamaRequest(params, messages)
@ -143,17 +135,17 @@ func (c *OllamaClient) CreateChatCompletionStream(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -166,7 +158,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return "", err return nil, err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -177,7 +169,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
var streamResp OllamaResponse var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp) err = json.Unmarshal(line, &streamResp)
if err != nil { if err != nil {
return "", err return nil, err
} }
if len(streamResp.Message.Content) > 0 { if len(streamResp.Message.Content) > 0 {
@ -189,12 +181,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
} }
if callback != nil { return &api.Message{
callback(model.Message{ Role: api.MessageRoleAssistant,
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) }, nil
}
return content.String(), nil
} }

View File

@ -11,11 +11,9 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func convertTools(tools []model.Tool) []Tool { func convertTools(tools []api.ToolSpec) []Tool {
openaiTools := make([]Tool, len(tools)) openaiTools := make([]Tool, len(tools))
for i, tool := range tools { for i, tool := range tools {
openaiTools[i].Type = "function" openaiTools[i].Type = "function"
@ -47,7 +45,7 @@ func convertTools(tools []model.Tool) []Tool {
return openaiTools return openaiTools
} }
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall { func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
converted := make([]ToolCall, len(toolCalls)) converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].Type = "function" converted[i].Type = "function"
@ -60,8 +58,8 @@ func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
return converted return converted
} }
func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall { func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
converted := make([]model.ToolCall, len(toolCalls)) converted := make([]api.ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].ID = call.ID converted[i].ID = call.ID
converted[i].Name = call.Function.Name converted[i].Name = call.Function.Name
@ -71,8 +69,8 @@ func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
} }
func createChatCompletionRequest( func createChatCompletionRequest(
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
) ChatCompletionRequest { ) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages)) requestMessages := make([]ChatCompletionMessage, 0, len(messages))
@ -117,56 +115,15 @@ func createChatCompletionRequest(
return request return request
} }
func handleToolCalls( func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) {
params model.RequestParameters,
content string,
toolCalls []ToolCall,
callback api.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey) req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx)) resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body) bytes, _ := io.ReadAll(resp.Body)
@ -178,35 +135,34 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*htt
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback, ) (*api.Message, error) {
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req := createChatCompletionRequest(params, messages) req := createChatCompletionRequest(params, messages)
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp ChatCompletionResponse var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return "", err return nil, err
} }
choice := completionResp.Choices[0] choice := completionResp.Choices[0]
@ -221,34 +177,27 @@ func (c *OpenAIClient) CreateChatCompletion(
toolCalls := choice.Message.ToolCalls toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content, toolCalls, callback, messages) return &api.Message{
if err != nil { Role: api.MessageRoleToolCall,
return content, err
}
return c.CreateChatCompletion(ctx, params, messages, callback)
}
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}) ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
// Return the user-facing message. return &api.Message{
return content, nil Role: api.MessageRoleAssistant,
Content: content,
}, nil
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params api.RequestParameters,
messages []model.Message, messages []api.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (string, error) { ) (*api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages") return nil, fmt.Errorf("Can't create completion from no messages")
} }
req := createChatCompletionRequest(params, messages) req := createChatCompletionRequest(params, messages)
@ -256,17 +205,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return "", err return nil, err
} }
resp, err := c.sendRequest(ctx, httpReq) resp, err := c.sendRequest(httpReq)
if err != nil { if err != nil {
return "", err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -285,7 +234,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return "", err return nil, err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -301,7 +250,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
var streamResp ChatCompletionStreamResponse var streamResp ChatCompletionStreamResponse
err = json.Unmarshal(line, &streamResp) err = json.Unmarshal(line, &streamResp)
if err != nil { if err != nil {
return "", err return nil, err
} }
delta := streamResp.Choices[0].Delta delta := streamResp.Choices[0].Delta
@ -309,7 +258,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
// Construct streamed tool_call arguments // Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls { for _, tc := range delta.ToolCalls {
if tc.Index == nil { if tc.Index == nil {
return "", fmt.Errorf("Unexpected nil index for streamed tool call.") return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
} }
if len(toolCalls) <= *tc.Index { if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc) toolCalls = append(toolCalls, tc)
@ -328,21 +277,15 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) return &api.Message{
if err != nil { Role: api.MessageRoleToolCall,
return content.String(), err
}
// Recurse into CreateChatCompletionStream with the tool call replies
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} else {
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) ToolCalls: convertToolCallToAPI(toolCalls),
} }, nil
} }
return content.String(), nil return &api.Message{
Role: api.MessageRoleAssistant,
Content: content.String(),
}, nil
} }

View File

@ -1,4 +1,4 @@
package model package api
import ( import (
"database/sql/driver" "database/sql/driver"
@ -6,11 +6,11 @@ import (
"fmt" "fmt"
) )
type Tool struct { type ToolSpec struct {
Name string Name string
Description string Description string
Parameters []ToolParameter Parameters []ToolParameter
Impl func(*Tool, map[string]interface{}) (string, error) Impl func(*ToolSpec, map[string]interface{}) (string, error)
} }
type ToolParameter struct { type ToolParameter struct {
@ -27,6 +27,12 @@ type ToolCall struct {
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"` Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
} }
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolCalls []ToolCall type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) { func (tc *ToolCalls) Scan(value any) (err error) {
@ -50,12 +56,6 @@ func (tc ToolCalls) Value() (driver.Value, error) {
return string(jsonBytes), nil return string(jsonBytes), nil
} }
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolResults []ToolResult type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) { func (tr *ToolResults) Scan(value any) (err error) {

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -36,7 +36,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
lastMessage := &messages[len(messages)-1] lastMessage := &messages[len(messages)-1]
if lastMessage.Role != model.MessageRoleAssistant { if lastMessage.Role != api.MessageRoleAssistant {
return fmt.Errorf("the last message in the conversation is not an assistant message") return fmt.Errorf("the last message in the conversation is not an assistant message")
} }
@ -50,7 +50,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -53,10 +53,10 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role != "" { if role != "" {
if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) { if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
} }
toEdit.Role = model.MessageRole(role) toEdit.Role = api.MessageRole(role)
} }
// Update the message in-place // Update the message in-place

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -20,19 +20,19 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []model.Message var messages []api.Message
// TODO: probably just make this part of the conversation // TODO: probably just make this part of the conversation
system := ctx.GetSystemPrompt() system := ctx.GetSystemPrompt()
if system != "" { if system != "" {
messages = append(messages, model.Message{ messages = append(messages, api.Message{
Role: model.MessageRoleSystem, Role: api.MessageRoleSystem,
Content: system, Content: system,
}) })
} }
messages = append(messages, model.Message{ messages = append(messages, api.Message{
Role: model.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -20,19 +20,19 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []model.Message var messages []api.Message
// TODO: stop supplying system prompt as a message // TODO: stop supplying system prompt as a message
system := ctx.GetSystemPrompt() system := ctx.GetSystemPrompt()
if system != "" { if system != "" {
messages = append(messages, model.Message{ messages = append(messages, api.Message{
Role: model.MessageRoleSystem, Role: api.MessageRoleSystem,
Content: system, Content: system,
}) })
} }
messages = append(messages, model.Message{ messages = append(messages, api.Message{
Role: model.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,7 +23,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
var toRemove []*model.Conversation var toRemove []*api.Conversation
for _, shortName := range args { for _, shortName := range args {
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
toRemove = append(toRemove, conversation) toRemove = append(toRemove, conversation)

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -30,8 +30,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No reply was provided.") return fmt.Errorf("No reply was provided.")
} }
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{
Role: model.MessageRoleUser, Role: api.MessageRoleUser,
Content: reply, Content: reply,
}) })
return nil return nil

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -43,11 +43,11 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
retryFromIdx := len(messages) - 1 - offset retryFromIdx := len(messages) - 1 - offset
// decrease retryFromIdx until we hit a user message // decrease retryFromIdx until we hit a user message
for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser { for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
retryFromIdx-- retryFromIdx--
} }
if messages[retryFromIdx].Role != model.MessageRoleUser { if messages[retryFromIdx].Role != api.MessageRoleUser {
return fmt.Errorf("No user messages to retry") return fmt.Errorf("No user messages to retry")
} }

View File

@ -10,36 +10,36 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) { func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
content := make(chan api.Chunk) // receives the reponse from LLM
defer close(content)
// render all content received over the channel
go ShowDelayedContent(content)
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model) m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return "", err return nil, err
} }
requestParams := model.RequestParameters{ requestParams := api.RequestParameters{
Model: m, Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens, MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature, Temperature: *ctx.Config.Defaults.Temperature,
ToolBag: ctx.EnabledTools, ToolBag: ctx.EnabledTools,
} }
response, err := provider.CreateChatCompletionStream( content := make(chan api.Chunk)
context.Background(), requestParams, messages, callback, content, defer close(content)
// render the content received over the channel
go ShowDelayedContent(content)
reply, err := provider.CreateChatCompletionStream(
context.Background(), requestParams, messages, content,
) )
if response != "" {
if reply.Content != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
fmt.Println() fmt.Println()
@ -48,12 +48,12 @@ func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Me
err = nil err = nil
} }
} }
return response, err return reply, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
// short name or exits the program // short name or exits the program
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation { func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
lmcli.Fatal("Could not lookup conversation: %v\n", err) lmcli.Fatal("Could not lookup conversation: %v\n", err)
@ -64,7 +64,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversatio
return c return c
} }
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) { func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not lookup conversation: %v", err) return nil, fmt.Errorf("Could not lookup conversation: %v", err)
@ -75,7 +75,7 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversat
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) { func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot) messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not load messages: %v\n", err)
@ -85,7 +85,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) { func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
if to == nil { if to == nil {
lmcli.Fatal("Can't prompt from an empty message.") lmcli.Fatal("Can't prompt from an empty message.")
} }
@ -97,7 +97,7 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
RenderConversation(ctx, append(existing, messages...), true) RenderConversation(ctx, append(existing, messages...), true)
var savedReplies []model.Message var savedReplies []api.Message
if persist && len(messages) > 0 { if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...) savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil { if err != nil {
@ -106,15 +106,15 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
} }
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))
var lastSavedMessage *model.Message var lastSavedMessage *api.Message
lastSavedMessage = to lastSavedMessage = to
if len(savedReplies) > 0 { if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1] lastSavedMessage = &savedReplies[len(savedReplies)-1]
} }
replyCallback := func(reply model.Message) { replyCallback := func(reply api.Message) {
if !persist { if !persist {
return return
} }
@ -131,16 +131,16 @@ func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages .
} }
} }
func FormatForExternalPrompt(messages []model.Message, system bool) string { func FormatForExternalPrompt(messages []api.Message, system bool) string {
sb := strings.Builder{} sb := strings.Builder{}
for _, message := range messages { for _, message := range messages {
if message.Content == "" { if message.Content == "" {
continue continue
} }
switch message.Role { switch message.Role {
case model.MessageRoleAssistant, model.MessageRoleToolCall: case api.MessageRoleAssistant, api.MessageRoleToolCall:
sb.WriteString("Assistant:\n\n") sb.WriteString("Assistant:\n\n")
case model.MessageRoleUser: case api.MessageRoleUser:
sb.WriteString("User:\n\n") sb.WriteString("User:\n\n")
default: default:
continue continue
@ -150,7 +150,7 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) { func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below. const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
Example conversation: Example conversation:
@ -177,28 +177,32 @@ Example response:
return "", err return "", err
} }
generateRequest := []model.Message{ generateRequest := []api.Message{
{ {
Role: model.MessageRoleSystem, Role: api.MessageRoleSystem,
Content: systemPrompt, Content: systemPrompt,
}, },
{ {
Role: model.MessageRoleUser, Role: api.MessageRoleUser,
Content: string(conversation), Content: string(conversation),
}, },
} }
m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel) m, provider, err := ctx.GetModelProvider(
*ctx.Config.Conversations.TitleGenerationModel,
)
if err != nil { if err != nil {
return "", err return "", err
} }
requestParams := model.RequestParameters{ requestParams := api.RequestParameters{
Model: m, Model: m,
MaxTokens: 25, MaxTokens: 25,
} }
response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil) response, err := provider.CreateChatCompletion(
context.Background(), requestParams, generateRequest,
)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -207,7 +211,7 @@ Example response:
var jsonResponse struct { var jsonResponse struct {
Title string `json:"title"` Title string `json:"title"`
} }
err = json.Unmarshal([]byte(response), &jsonResponse) err = json.Unmarshal([]byte(response.Content), &jsonResponse)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -272,7 +276,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {
// RenderConversation renders the given messages to TTY, with optional space // RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters // for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true) // are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) { func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
l := len(messages) l := len(messages)
for i, message := range messages { for i, message := range messages {
RenderMessage(ctx, &message) RenderMessage(ctx, &message)
@ -283,7 +287,7 @@ func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForRe
} }
} }
func RenderMessage(ctx *lmcli.Context, m *model.Message) { func RenderMessage(ctx *lmcli.Context, m *api.Message) {
var messageAge string var messageAge string
if m.CreatedAt.IsZero() { if m.CreatedAt.IsZero() {
messageAge = "now" messageAge = "now"
@ -295,11 +299,11 @@ func RenderMessage(ctx *lmcli.Context, m *model.Message) {
headerStyle := lipgloss.NewStyle().Bold(true) headerStyle := lipgloss.NewStyle().Bold(true)
switch m.Role { switch m.Role {
case model.MessageRoleSystem: case api.MessageRoleSystem:
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
case model.MessageRoleUser: case api.MessageRoleUser:
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
case model.MessageRoleAssistant: case api.MessageRoleAssistant:
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
} }

View File

@ -6,13 +6,12 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/agent"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@ -24,7 +23,7 @@ type Context struct {
Store ConversationStore Store ConversationStore
Chroma *tty.ChromaHighlighter Chroma *tty.ChromaHighlighter
EnabledTools []model.Tool EnabledTools []api.ToolSpec
SystemPromptFile string SystemPromptFile string
} }
@ -50,9 +49,9 @@ func NewContext() (*Context, error) {
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
var enabledTools []model.Tool var enabledTools []api.ToolSpec
for _, toolName := range config.Tools.EnabledTools { for _, toolName := range config.Tools.EnabledTools {
tool, ok := tools.AvailableTools[toolName] tool, ok := agent.AvailableTools[toolName]
if ok { if ok {
enabledTools = append(enabledTools, tool) enabledTools = append(enabledTools, tool)
} }
@ -79,7 +78,7 @@ func (c *Context) GetModels() (models []string) {
return return
} }
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) { func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
var provider string var provider string

View File

@ -8,32 +8,32 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
sqids "github.com/sqids/sqids-go" sqids "github.com/sqids/sqids-go"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ConversationStore interface { type ConversationStore interface {
ConversationByShortName(shortName string) (*model.Conversation, error) ConversationByShortName(shortName string) (*api.Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]model.Message, error) RootMessages(conversationID uint) ([]api.Message, error)
LatestConversationMessages() ([]model.Message, error) LatestConversationMessages() ([]api.Message, error)
StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
UpdateConversation(conversation *model.Conversation) error UpdateConversation(conversation *api.Conversation) error
DeleteConversation(conversation *model.Conversation) error DeleteConversation(conversation *api.Conversation) error
CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
MessageByID(messageID uint) (*model.Message, error) MessageByID(messageID uint) (*api.Message, error)
MessageReplies(messageID uint) ([]model.Message, error) MessageReplies(messageID uint) ([]api.Message, error)
UpdateMessage(message *model.Message) error UpdateMessage(message *api.Message) error
DeleteMessage(message *model.Message, prune bool) error DeleteMessage(message *api.Message, prune bool) error
CloneBranch(toClone model.Message) (*model.Message, uint, error) CloneBranch(toClone api.Message) (*api.Message, uint, error)
Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
PathToRoot(message *model.Message) ([]model.Message, error) PathToRoot(message *api.Message) ([]api.Message, error)
PathToLeaf(message *model.Message) ([]model.Message, error) PathToLeaf(message *api.Message) ([]api.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -43,8 +43,8 @@ type SQLStore struct {
func NewSQLStore(db *gorm.DB) (*SQLStore, error) { func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
models := []any{ models := []any{
&model.Conversation{}, &api.Conversation{},
&model.Message{}, &api.Message{},
} }
for _, x := range models { for _, x := range models {
@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) createConversation() (*model.Conversation, error) { func (s *SQLStore) createConversation() (*api.Conversation, error) {
// Create the new conversation // Create the new conversation
c := &model.Conversation{} c := &api.Conversation{}
err := s.db.Save(c).Error err := s.db.Save(c).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*model.Conversation, error) {
return c, nil return c, nil
} }
func (s *SQLStore) UpdateConversation(c *model.Conversation) error { func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(c).Error return s.db.Updates(c).Error
} }
func (s *SQLStore) DeleteConversation(c *model.Conversation) error { func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
// Delete messages first // Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(c).Error return s.db.Delete(c).Error
} }
func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error { func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
panic("Not yet implemented") panic("Not yet implemented")
//return s.db.Delete(&message).Error //return s.db.Delete(&message).Error
} }
func (s *SQLStore) UpdateMessage(m *model.Message) error { func (s *SQLStore) UpdateMessage(m *api.Message) error {
if m == nil || m.ID == 0 { if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)") return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *model.Message) error {
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var conversations []model.Conversation var conversations []api.Conversation
// ignore error for completions // ignore error for completions
s.db.Find(&conversations) s.db.Find(&conversations)
completions := make([]string, 0, len(conversations)) completions := make([]string, 0, len(conversations))
@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
return completions return completions
} }
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) { func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
if shortName == "" { if shortName == "" {
return nil, errors.New("shortName is empty") return nil, errors.New("shortName is empty")
} }
var conversation model.Conversation var conversation api.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err return &conversation, err
} }
func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) { func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
var rootMessages []model.Message var rootMessages []api.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
return rootMessages, nil return rootMessages, nil
} }
func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) { func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
var message model.Message var message api.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err return &message, err
} }
func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) { func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
var replies []model.Message var replies []api.Message
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
return replies, err return replies, err
} }
// StartConversation starts a new conversation with the provided messages // StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) { func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message") return nil, nil, fmt.Errorf("Must provide at least 1 message")
} }
@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversa
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
messages = append([]model.Message{messages[0]}, newMessages...) messages = append([]api.Message{messages[0]}, newMessages...)
} }
return conversation, messages, nil return conversation, messages, nil
} }
// CloneConversation clones the given conversation and all of its root meesages // CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) { func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID) rootMessages, err := s.RootMessages(toClone.ID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Convers
} }
// Reply to a message with a series of messages (each following the next) // Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) { func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
var savedMessages []model.Message var savedMessages []api.Message
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to currentParent := to
@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.
// CloneBranch returns a deep clone of the given message and its replies, returning // CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as // a new message object. The new message will be attached to the same parent as
// the messageToClone // the messageToClone
func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) { func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
newMessage := messageToClone newMessage := messageToClone
newMessage.ID = 0 newMessage.ID = 0
newMessage.Replies = nil newMessage.Replies = nil
@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, ui
return &newMessage, replyCount, nil return &newMessage, replyCount, nil
} }
func fetchMessages(db *gorm.DB) ([]model.Message, error) { func fetchMessages(db *gorm.DB) ([]api.Message, error) {
var messages []model.Message var messages []api.Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil { if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err) return nil, fmt.Errorf("Could not fetch messages: %v", err)
} }
messageMap := make(map[uint]model.Message) messageMap := make(map[uint]api.Message)
for i, message := range messages { for i, message := range messages {
messageMap[messages[i].ID] = message messageMap[messages[i].ID] = message
} }
// Create a map to store replies by their parent ID // Create a map to store replies by their parent ID
repliesMap := make(map[uint][]model.Message) repliesMap := make(map[uint][]api.Message)
for i, message := range messages { for i, message := range messages {
if messages[i].ParentID != nil { if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
// Assign replies, parent, and selected reply to each message // Assign replies, parent, and selected reply to each message
for i := range messages { for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists { if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]model.Message, len(replies)) messages[i].Replies = make([]api.Message, len(replies))
for j, m := range replies { for j, m := range replies {
messages[i].Replies[j] = m messages[i].Replies[j] = m
} }
@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]model.Message, error) {
return messages, nil return messages, nil
} }
func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) { func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
var messages []model.Message var messages []api.Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Create a map to store messages by their ID // Create a map to store messages by their ID
messageMap := make(map[uint]*model.Message) messageMap := make(map[uint]*api.Message)
for i := range messages { for i := range messages {
messageMap[messages[i].ID] = &messages[i] messageMap[messages[i].ID] = &messages[i]
} }
// Build the path // Build the path
var path []model.Message var path []api.Message
nextID := &message.ID nextID := &message.ID
for { for {
@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message
// PathToRoot traverses the provided message's Parent until reaching the tree // PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order // root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided) // (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) { func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
path, err := s.buildPath(message, func(m *model.Message) *uint { path, err := s.buildPath(message, func(m *api.Message) *uint {
return m.ParentID return m.ParentID
}) })
if err != nil { if err != nil {
@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
// PathToLeaf traverses the provided message's SelectedReply until reaching a // PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological // tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf) // order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) { func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
return s.buildPath(message, func(m *model.Message) *uint { return s.buildPath(message, func(m *api.Message) *uint {
return m.SelectedReplyID return m.SelectedReplyID
}) })
} }
func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) { func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
var latestMessages []model.Message var latestMessages []api.Message
subQuery := s.db.Model(&model.Message{}). subQuery := s.db.Model(&api.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id"). Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id") Group("conversation_id")
err := s.db.Model(&model.Message{}). err := s.db.Model(&api.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id"). Group("messages.conversation_id").
Order("created_at DESC"). Order("created_at DESC").

View File

@ -1,48 +0,0 @@
package tools
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
var AvailableTools map[string]model.Tool = map[string]model.Tool{
"dir_tree": DirTreeTool,
"read_dir": ReadDirTool,
"read_file": ReadFileTool,
"write_file": WriteFileTool,
"file_insert_lines": FileInsertLinesTool,
"file_replace_lines": FileReplaceLinesTool,
}
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
var toolResults []model.ToolResult
for _, toolCall := range toolCalls {
var tool *model.Tool
for _, available := range toolBag {
if available.Name == toolCall.Name {
tool = &available
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
}
// Execute the tool
result, err := tool.Impl(tool, toolCall.Parameters)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
}
toolResult := model.ToolResult{
ToolCallID: toolCall.ID,
ToolName: toolCall.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -4,7 +4,6 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/spinner"
@ -16,37 +15,39 @@ import (
// custom tea.Msg types // custom tea.Msg types
type ( type (
// sent on each chunk received from LLM
msgResponseChunk api.Chunk
// sent when response is finished being received
msgResponseEnd string
// a special case of common.MsgError that stops the response waiting animation
msgResponseError error
// sent on each completed reply
msgResponse models.Message
// sent when a conversation is (re)loaded // sent when a conversation is (re)loaded
msgConversationLoaded struct { msgConversationLoaded struct {
conversation *models.Conversation conversation *api.Conversation
rootMessages []models.Message rootMessages []api.Message
} }
// sent when a new conversation title generated // sent when a new conversation title generated
msgConversationTitleGenerated string msgConversationTitleGenerated string
// sent when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted struct {
isNew bool isNew bool
conversation *models.Conversation conversation *api.Conversation
messages []models.Message messages []api.Message
} }
// sent when a conversation's messages are laoded
msgMessagesLoaded []api.Message
// a special case of common.MsgError that stops the response waiting animation
msgChatResponseError error
// sent on each chunk received from LLM
msgChatResponseChunk api.Chunk
// sent on each completed reply
msgChatResponse *api.Message
// sent when the response is canceled
msgChatResponseCanceled struct{}
// sent when results from a tool call are returned
msgToolResults []api.ToolResult
// sent when the given message is made the new selected reply of its parent // sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *models.Message msgSelectedReplyCycled *api.Message
// sent when the given message is made the new selected root of the current conversation // sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *models.Message msgSelectedRootCycled *api.Message
// sent when a message's contents are updated and saved // sent when a message's contents are updated and saved
msgMessageUpdated *models.Message msgMessageUpdated *api.Message
// sent when a message is cloned, with the cloned message // sent when a message is cloned, with the cloned message
msgMessageCloned *models.Message msgMessageCloned *api.Message
) )
type focusState int type focusState int
@ -77,14 +78,14 @@ type Model struct {
// app state // app state
state state // current overall status of the view state state // current overall status of the view
conversation *models.Conversation conversation *api.Conversation
rootMessages []models.Message rootMessages []api.Message
messages []models.Message messages []api.Message
selectedMessage int selectedMessage int
editorTarget editorTarget editorTarget editorTarget
stopSignal chan struct{} stopSignal chan struct{}
replyChan chan models.Message replyChan chan api.Message
replyChunkChan chan api.Chunk chatReplyChunks chan api.Chunk
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
// ui state // ui state
@ -111,12 +112,12 @@ func Chat(shared shared.Shared) Model {
Shared: shared, Shared: shared,
state: idle, state: idle,
conversation: &models.Conversation{}, conversation: &api.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan struct{}), stopSignal: make(chan struct{}),
replyChan: make(chan models.Message), replyChan: make(chan api.Message),
replyChunkChan: make(chan api.Chunk), chatReplyChunks: make(chan api.Chunk),
wrap: true, wrap: true,
selectedMessage: -1, selectedMessage: -1,
@ -144,8 +145,8 @@ func Chat(shared shared.Shared) Model {
system := shared.Ctx.GetSystemPrompt() system := shared.Ctx.GetSystemPrompt()
if system != "" { if system != "" {
m.messages = []models.Message{{ m.messages = []api.Message{{
Role: models.MessageRoleSystem, Role: api.MessageRoleSystem,
Content: system, Content: system,
}} }}
} }
@ -166,6 +167,5 @@ func Chat(shared shared.Shared) Model {
func (m Model) Init() tea.Cmd { func (m Model) Init() tea.Cmd {
return tea.Batch( return tea.Batch(
m.waitForResponseChunk(), m.waitForResponseChunk(),
m.waitForResponse(),
) )
} }

View File

@ -2,16 +2,18 @@ package chat
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/agent"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
) )
func (m *Model) setMessage(i int, msg models.Message) { func (m *Model) setMessage(i int, msg api.Message) {
if i >= len(m.messages) { if i >= len(m.messages) {
panic("i out of range") panic("i out of range")
} }
@ -19,7 +21,7 @@ func (m *Model) setMessage(i int, msg models.Message) {
m.messageCache[i] = m.renderMessage(i) m.messageCache[i] = m.renderMessage(i)
} }
func (m *Model) addMessage(msg models.Message) { func (m *Model) addMessage(msg api.Message) {
m.messages = append(m.messages, msg) m.messages = append(m.messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1)) m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
} }
@ -88,7 +90,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
} }
} }
func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd { func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateConversation(conversation) err := m.Shared.Ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
@ -101,7 +103,7 @@ func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.C
// Clones the given message (and its descendents). If selected is true, updates // Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to // either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone // point to the new clone
func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd { func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message) msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil { if err != nil {
@ -123,7 +125,7 @@ func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
} }
} }
func (m *Model) updateMessageContent(message *models.Message) tea.Cmd { func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateMessage(message) err := m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
@ -133,7 +135,7 @@ func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
} }
} }
func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) { func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
currentIndex := -1 currentIndex := -1
for i, reply := range choices { for i, reply := range choices {
if reply.ID == selected.ID { if reply.ID == selected.ID {
@ -158,7 +160,7 @@ func cycleSelectedMessage(selected *models.Message, choices []models.Message, di
return &choices[next], nil return &choices[next], nil
} }
func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd { func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 { if len(m.rootMessages) < 2 {
return nil return nil
} }
@ -178,7 +180,7 @@ func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDir
} }
} }
func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd { func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 { if len(message.Replies) < 2 {
return nil return nil
} }
@ -218,15 +220,12 @@ func (m *Model) persistConversation() tea.Cmd {
// else, we'll handle updating an existing conversation's messages // else, we'll handle updating an existing conversation's messages
for i := range messages { for i := range messages {
if messages[i].ID > 0 { if messages[i].ID > 0 {
// message has an ID, update its contents // message has an ID, update it
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i]) err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
} else if i > 0 { } else if i > 0 {
if messages[i].Content == "" {
continue
}
// messages is new, so add it as a reply to previous message // messages is new, so add it as a reply to previous message
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i]) saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil { if err != nil {
@ -243,13 +242,23 @@ func (m *Model) persistConversation() tea.Cmd {
} }
} }
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
results, err := agent.ExecuteToolCalls(toolCalls, m.Ctx.EnabledTools)
if err != nil {
return shared.MsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd { func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse m.state = pendingResponse
m.replyCursor.Blink = false m.replyCursor.Blink = false
m.tokenCount = 0
m.startTime = time.Now() m.startTime = time.Now()
m.elapsed = 0 m.elapsed = 0
m.tokenCount = 0
return func() tea.Msg { return func() tea.Msg {
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model) model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
@ -257,36 +266,34 @@ func (m *Model) promptLLM() tea.Cmd {
return shared.MsgError(err) return shared.MsgError(err)
} }
requestParams := models.RequestParameters{ requestParams := api.RequestParameters{
Model: model, Model: model,
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens, MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature, Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
ToolBag: m.Shared.Ctx.EnabledTools, ToolBag: m.Shared.Ctx.EnabledTools,
} }
replyHandler := func(msg models.Message) {
m.replyChan <- msg
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
canceled := false
go func() { go func() {
select { select {
case <-m.stopSignal: case <-m.stopSignal:
canceled = true
cancel() cancel()
} }
}() }()
resp, err := provider.CreateChatCompletionStream( resp, err := provider.CreateChatCompletionStream(
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, ctx, requestParams, m.messages, m.chatReplyChunks,
) )
if err != nil && !canceled { if errors.Is(err, context.Canceled) {
return msgResponseError(err) return msgChatResponseCanceled(struct{}{})
} }
return msgResponseEnd(resp) if err != nil {
return msgChatResponseError(err)
}
return msgChatResponse(resp)
} }
} }

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
@ -150,12 +150,12 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return true, nil return true, nil
} }
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser { if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser {
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message")) return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
} }
m.addMessage(models.Message{ m.addMessage(api.Message{
Role: models.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -4,7 +4,7 @@ import (
"strings" "strings"
"time" "time"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
@ -21,15 +21,9 @@ func (m *Model) HandleResize(width, height int) {
} }
} }
func (m *Model) waitForResponse() tea.Cmd {
return func() tea.Msg {
return msgResponse(<-m.replyChan)
}
}
func (m *Model) waitForResponseChunk() tea.Cmd { func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
return msgResponseChunk(<-m.replyChunkChan) return msgChatResponseChunk(<-m.chatReplyChunks)
} }
} }
@ -48,7 +42,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname { if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
// clear existing messages if we're loading a new conversation // clear existing messages if we're loading a new conversation
m.messages = []models.Message{} m.messages = []api.Message{}
m.selectedMessage = 0 m.selectedMessage = 0
} }
} }
@ -87,7 +81,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgResponseChunk: case msgChatResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" { if msg.Content == "" {
@ -100,8 +94,8 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.setMessageContents(last, m.messages[last].Content+msg.Content) m.setMessageContents(last, m.messages[last].Content+msg.Content)
} else { } else {
// use chunk in new message // use chunk in new message
m.addMessage(models.Message{ m.addMessage(api.Message{
Role: models.MessageRoleAssistant, Role: api.MessageRoleAssistant,
Content: msg.Content, Content: msg.Content,
}) })
} }
@ -113,10 +107,10 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.tokenCount += msg.TokenCount m.tokenCount += msg.TokenCount
m.elapsed = time.Now().Sub(m.startTime) m.elapsed = time.Now().Sub(m.startTime)
case msgResponse: case msgChatResponse:
cmds = append(cmds, m.waitForResponse()) // wait for the next response m.state = idle
reply := models.Message(msg) reply := (*api.Message)(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1 last := len(m.messages) - 1
@ -124,11 +118,18 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
panic("Unexpected empty messages handling msgAssistantReply") panic("Unexpected empty messages handling msgAssistantReply")
} }
if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() { if m.messages[last].Role.IsAssistant() {
// this was a continuation, so replace the previous message with the completed reply // TODO: handle continuations gracefully - some models support them well, others fail horribly.
m.setMessage(last, reply) m.setMessage(last, *reply)
} else { } else {
m.addMessage(reply) m.addMessage(*reply)
}
switch reply.Role {
case api.MessageRoleToolCall:
// TODO: user confirmation before execution
// m.state = waitingForConfirmation
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
} }
if m.persistence { if m.persistence {
@ -140,17 +141,32 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
m.updateContent() m.updateContent()
case msgResponseEnd: case msgChatResponseCanceled:
m.state = idle m.state = idle
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent() m.updateContent()
case msgResponseError: case msgChatResponseError:
m.state = idle m.state = idle
m.Shared.Err = error(msg) m.Shared.Err = error(msg)
m.updateContent()
case msgToolResults:
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if m.messages[last].Role != api.MessageRoleToolCall {
panic("Previous message not a tool call, unexpected")
}
m.addMessage(api.Message{
Role: api.MessageRoleToolResult,
ToolResults: api.ToolResults(msg),
})
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
m.updateContent() m.updateContent()
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
@ -167,7 +183,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.conversation = msg.conversation m.conversation = msg.conversation
m.messages = msg.messages m.messages = msg.messages
if msg.isNew { if msg.isNew {
m.rootMessages = []models.Message{m.messages[0]} m.rootMessages = []api.Message{m.messages[0]}
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@ -63,22 +63,22 @@ func (m Model) View() string {
return lipgloss.JoinVertical(lipgloss.Left, sections...) return lipgloss.JoinVertical(lipgloss.Left, sections...)
} }
func (m *Model) renderMessageHeading(i int, message *models.Message) string { func (m *Model) renderMessageHeading(i int, message *api.Message) string {
icon := "" icon := ""
friendly := message.Role.FriendlyRole() friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Faint(true).Bold(true) style := lipgloss.NewStyle().Faint(true).Bold(true)
switch message.Role { switch message.Role {
case models.MessageRoleSystem: case api.MessageRoleSystem:
icon = "⚙️" icon = "⚙️"
case models.MessageRoleUser: case api.MessageRoleUser:
style = userStyle style = userStyle
case models.MessageRoleAssistant: case api.MessageRoleAssistant:
style = assistantStyle style = assistantStyle
case models.MessageRoleToolCall: case api.MessageRoleToolCall:
style = assistantStyle style = assistantStyle
friendly = models.MessageRoleAssistant.FriendlyRole() friendly = api.MessageRoleAssistant.FriendlyRole()
case models.MessageRoleToolResult: case api.MessageRoleToolResult:
icon = "🔧" icon = "🔧"
} }
@ -139,21 +139,21 @@ func (m *Model) renderMessage(i int) string {
} }
// Show the assistant's cursor // Show the assistant's cursor
if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant { if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == api.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }
// Write tool call info // Write tool call info
var toolString string var toolString string
switch msg.Role { switch msg.Role {
case models.MessageRoleToolCall: case api.MessageRoleToolCall:
bytes, err := yaml.Marshal(msg.ToolCalls) bytes, err := yaml.Marshal(msg.ToolCalls)
if err != nil { if err != nil {
toolString = "Could not serialize ToolCalls" toolString = "Could not serialize ToolCalls"
} else { } else {
toolString = "tool_calls:\n" + string(bytes) toolString = "tool_calls:\n" + string(bytes)
} }
case models.MessageRoleToolResult: case api.MessageRoleToolResult:
if !m.showToolResults { if !m.showToolResults {
break break
} }
@ -221,11 +221,11 @@ func (m *Model) conversationMessagesView() string {
m.messageOffsets[i] = lineCnt m.messageOffsets[i] = lineCnt
switch message.Role { switch message.Role {
case models.MessageRoleToolCall: case api.MessageRoleToolCall:
if !m.showToolResults && message.Content == "" { if !m.showToolResults && message.Content == "" {
continue continue
} }
case models.MessageRoleToolResult: case api.MessageRoleToolResult:
if !m.showToolResults { if !m.showToolResults {
continue continue
} }
@ -251,9 +251,9 @@ func (m *Model) conversationMessagesView() string {
} }
// Render a placeholder for the incoming assistant reply // Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) { if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &models.Message{ heading := m.renderMessageHeading(-1, &api.Message{
Role: models.MessageRoleAssistant, Role: api.MessageRoleAssistant,
}) })
sb.WriteString(heading) sb.WriteString(heading)
sb.WriteString("\n") sb.WriteString("\n")

View File

@ -5,7 +5,7 @@ import (
"strings" "strings"
"time" "time"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
@ -16,15 +16,15 @@ import (
) )
type loadedConversation struct { type loadedConversation struct {
conv models.Conversation conv api.Conversation
lastReply models.Message lastReply api.Message
} }
type ( type (
// sent when conversation list is loaded // sent when conversation list is loaded
msgConversationsLoaded ([]loadedConversation) msgConversationsLoaded ([]loadedConversation)
// sent when a conversation is selected // sent when a conversation is selected
msgConversationSelected models.Conversation msgConversationSelected api.Conversation
) )
type Model struct { type Model struct {