Compare commits

..

No commits in common. "31df055430660fb3280c0c27490161a34af0cde1" and "85a2abbbf3bbf060675496845f6dcfd4dcb305ad" have entirely different histories.

35 changed files with 784 additions and 636 deletions

View File

@ -1,48 +0,0 @@
package agent
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/agent/toolbox"
"git.mlow.ca/mlow/lmcli/pkg/api"
)
// AvailableTools maps a tool's public name to its ToolSpec. This is the
// registry of tool implementations (from the toolbox package) that can be
// offered to a model; ExecuteToolCalls resolves requested tool names
// against specs drawn from this set.
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
"dir_tree": toolbox.DirTreeTool,
"read_dir": toolbox.ReadDirTool,
"read_file": toolbox.ReadFileTool,
"write_file": toolbox.WriteFileTool,
"file_insert_lines": toolbox.FileInsertLinesTool,
"file_replace_lines": toolbox.FileReplaceLinesTool,
}
// ExecuteToolCalls runs each requested tool call against the provided set
// of available tool specs, in call order, and returns one ToolResult per
// call.
//
// It fails fast: if a requested tool name is not among the available specs
// (likely a model hallucination), or if any tool implementation returns an
// error, no partial results are returned. Tool implementation errors are
// wrapped with %w so callers can inspect them with errors.Is/errors.As.
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
	toolResults := make([]api.ToolResult, 0, len(calls))
	for _, call := range calls {
		// Resolve the requested tool name to a spec.
		var tool *api.ToolSpec
		for i := range available {
			if available[i].Name == call.Name {
				tool = &available[i]
				break
			}
		}
		if tool == nil {
			return nil, fmt.Errorf("requested tool %q is not available (hallucination?)", call.Name)
		}
		// Execute the tool implementation with the call's parameters.
		result, err := tool.Impl(tool, call.Parameters)
		if err != nil {
			return nil, fmt.Errorf("tool %q: %w", call.Name, err)
		}
		toolResults = append(toolResults, api.ToolResult{
			ToolCallID: call.ID,
			ToolName:   call.Name,
			Result:     result,
		})
	}
	return toolResults, nil
}

View File

@ -2,41 +2,35 @@ package api
import ( import (
"context" "context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type ReplyCallback func(Message) type ReplyCallback func(model.Message)
type Chunk struct { type Chunk struct {
Content string Content string
TokenCount uint TokenCount uint
} }
type RequestParameters struct { type ChatCompletionClient interface {
Model string
MaxTokens int
Temperature float32
TopP float32
ToolBag []ToolSpec
}
type ChatCompletionProvider interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string. // complete user-facing response is returned as a string.
CreateChatCompletion( CreateChatCompletion(
ctx context.Context, ctx context.Context,
params RequestParameters, params model.RequestParameters,
messages []Message, messages []model.Message,
) (*Message, error) callback ReplyCallback,
) (string, error)
// Like CreateChatCompletion, except the response is streamed via // Like CreateChatCompletion, except the response is streamed via
// the output channel as it's received. // the output channel as it's received.
CreateChatCompletionStream( CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params RequestParameters, params model.RequestParameters,
messages []Message, messages []model.Message,
chunks chan<- Chunk, callback ReplyCallback,
) (*Message, error) output chan<- Chunk,
) (string, error)
} }

View File

@ -1,11 +0,0 @@
package api
import "database/sql"
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}

View File

@ -11,9 +11,11 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func buildRequest(params api.RequestParameters, messages []api.Message) Request { func buildRequest(params model.RequestParameters, messages []model.Message) Request {
requestBody := Request{ requestBody := Request{
Model: params.Model, Model: params.Model,
Messages: make([]Message, len(messages)), Messages: make([]Message, len(messages)),
@ -28,7 +30,7 @@ func buildRequest(params api.RequestParameters, messages []api.Message) Request
} }
startIdx := 0 startIdx := 0
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem { if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
requestBody.System = messages[0].Content requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:] requestBody.Messages = requestBody.Messages[1:]
startIdx = 1 startIdx = 1
@ -46,7 +48,7 @@ func buildRequest(params api.RequestParameters, messages []api.Message) Request
message := &requestBody.Messages[i] message := &requestBody.Messages[i]
switch msg.Role { switch msg.Role {
case api.MessageRoleToolCall: case model.MessageRoleToolCall:
message.Role = "assistant" message.Role = "assistant"
if msg.Content != "" { if msg.Content != "" {
message.Content = msg.Content message.Content = msg.Content
@ -61,7 +63,7 @@ func buildRequest(params api.RequestParameters, messages []api.Message) Request
} else { } else {
message.Content = xmlString message.Content = xmlString
} }
case api.MessageRoleToolResult: case model.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString() xmlString, err := xmlFuncResults.XMLString()
if err != nil { if err != nil {
@ -103,25 +105,26 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) (*api.Message, error) { callback api.ReplyCallback,
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
request := buildRequest(params, messages) request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
var response Response var response Response
err = json.NewDecoder(resp.Body).Decode(&response) err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode response: %v", err) return "", fmt.Errorf("failed to decode response: %v", err)
} }
sb := strings.Builder{} sb := strings.Builder{}
@ -134,28 +137,34 @@ func (c *AnthropicClient) CreateChatCompletion(
} }
for _, content := range response.Content { for _, content := range response.Content {
var reply model.Message
switch content.Type { switch content.Type {
case "text": case "text":
sb.WriteString(content.Text) reply = model.Message{
Role: model.MessageRoleAssistant,
Content: content.Text,
}
sb.WriteString(reply.Content)
default: default:
return nil, fmt.Errorf("unsupported message type: %s", content.Type) return "", fmt.Errorf("unsupported message type: %s", content.Type)
}
if callback != nil {
callback(reply)
} }
} }
return &api.Message{ return sb.String(), nil
Role: api.MessageRoleAssistant,
Content: sb.String(),
}, nil
} }
func (c *AnthropicClient) CreateChatCompletionStream( func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (*api.Message, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -163,18 +172,19 @@ func (c *AnthropicClient) CreateChatCompletionStream(
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(ctx, c, request)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
sb := strings.Builder{} sb := strings.Builder{}
lastMessage := messages[len(messages)-1] lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() { if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll // this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result // include its contents in the final result
// TODO: handle this at higher level
sb.WriteString(lastMessage.Content) sb.WriteString(lastMessage.Content)
continuation = true
} }
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
@ -190,29 +200,29 @@ func (c *AnthropicClient) CreateChatCompletionStream(
var event map[string]interface{} var event map[string]interface{}
err := json.Unmarshal([]byte(line), &event) err := json.Unmarshal([]byte(line), &event)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err) return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
} }
eventType, ok := event["type"].(string) eventType, ok := event["type"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid event: %s", line) return "", fmt.Errorf("invalid event: %s", line)
} }
switch eventType { switch eventType {
case "error": case "error":
return nil, fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default: default:
return nil, fmt.Errorf("unknown event type: %s", eventType) return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
} }
} else if strings.HasPrefix(line, "data:") { } else if strings.HasPrefix(line, "data:") {
data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
var event map[string]interface{} var event map[string]interface{}
err := json.Unmarshal([]byte(data), &event) err := json.Unmarshal([]byte(data), &event)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal event data: %v", err) return "", fmt.Errorf("failed to unmarshal event data: %v", err)
} }
eventType, ok := event["type"].(string) eventType, ok := event["type"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid event type") return "", fmt.Errorf("invalid event type")
} }
switch eventType { switch eventType {
@ -225,11 +235,11 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "content_block_delta": case "content_block_delta":
delta, ok := event["delta"].(map[string]interface{}) delta, ok := event["delta"].(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("invalid content block delta") return "", fmt.Errorf("invalid content block delta")
} }
text, ok := delta["text"].(string) text, ok := delta["text"].(string)
if !ok { if !ok {
return nil, fmt.Errorf("invalid text delta") return "", fmt.Errorf("invalid text delta")
} }
sb.WriteString(text) sb.WriteString(text)
output <- api.Chunk{ output <- api.Chunk{
@ -241,7 +251,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
case "message_delta": case "message_delta":
delta, ok := event["delta"].(map[string]interface{}) delta, ok := event["delta"].(map[string]interface{})
if !ok { if !ok {
return nil, fmt.Errorf("invalid message delta") return "", fmt.Errorf("invalid message delta")
} }
stopReason, ok := delta["stop_reason"].(string) stopReason, ok := delta["stop_reason"].(string)
if ok && stopReason == "stop_sequence" { if ok && stopReason == "stop_sequence" {
@ -251,7 +261,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
start := strings.Index(content, "<function_calls>") start := strings.Index(content, "<function_calls>")
if start == -1 { if start == -1 {
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found") return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
} }
sb.WriteString(FUNCTION_STOP_SEQUENCE) sb.WriteString(FUNCTION_STOP_SEQUENCE)
@ -259,31 +269,59 @@ func (c *AnthropicClient) CreateChatCompletionStream(
Content: FUNCTION_STOP_SEQUENCE, Content: FUNCTION_STOP_SEQUENCE,
TokenCount: 1, TokenCount: 1,
} }
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
var functionCalls XMLFunctionCalls var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls) err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err) return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
} }
return &api.Message{ toolCall := model.Message{
Role: api.MessageRoleToolCall, Role: model.MessageRoleToolCall,
// function call xml stripped from content for model interop // function call xml stripped from content for model interop
Content: strings.TrimSpace(content[:start]), Content: strings.TrimSpace(content[:start]),
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
}, nil }
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return "", err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
content := sb.String() content := sb.String()
return &api.Message{ if callback != nil {
Role: api.MessageRoleAssistant, callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}, nil })
}
return content, nil
case "error": case "error":
return nil, fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default: default:
fmt.Printf("\nUnrecognized event: %s\n", data) fmt.Printf("\nUnrecognized event: %s\n", data)
} }
@ -291,8 +329,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
} }
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read response body: %v", err) return "", fmt.Errorf("failed to read response body: %v", err)
} }
return nil, fmt.Errorf("unexpected end of stream") return "", fmt.Errorf("unexpected end of stream")
} }

View File

@ -6,7 +6,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
const FUNCTION_STOP_SEQUENCE = "</function_calls>" const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@ -97,7 +97,7 @@ func parseFunctionParametersXML(params string) map[string]interface{} {
return ret return ret
} }
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools { func convertToolsToXMLTools(tools []model.Tool) XMLTools {
converted := make([]XMLToolDescription, len(tools)) converted := make([]XMLToolDescription, len(tools))
for i, tool := range tools { for i, tool := range tools {
converted[i].ToolName = tool.Name converted[i].ToolName = tool.Name
@ -117,8 +117,8 @@ func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
} }
} }
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall { func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke)) toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
for i, invoke := range functionCalls.Invoke { for i, invoke := range functionCalls.Invoke {
toolCalls[i].Name = invoke.ToolName toolCalls[i].Name = invoke.ToolName
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String) toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
@ -126,7 +126,7 @@ func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.To
return toolCalls return toolCalls
} }
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls { func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
converted := make([]XMLFunctionInvoke, len(toolCalls)) converted := make([]XMLFunctionInvoke, len(toolCalls))
for i, toolCall := range toolCalls { for i, toolCall := range toolCalls {
var params XMLFunctionInvokeParameters var params XMLFunctionInvokeParameters
@ -145,7 +145,7 @@ func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCal
} }
} }
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults { func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
converted := make([]XMLFunctionResult, len(toolResults)) converted := make([]XMLFunctionResult, len(toolResults))
for i, result := range toolResults { for i, result := range toolResults {
converted[i].ToolName = result.ToolName converted[i].ToolName = result.ToolName
@ -156,11 +156,11 @@ func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunc
} }
} }
func buildToolsSystemPrompt(tools []api.ToolSpec) string { func buildToolsSystemPrompt(tools []model.Tool) string {
xmlTools := convertToolsToXMLTools(tools) xmlTools := convertToolsToXMLTools(tools)
xmlToolsString, err := xmlTools.XMLString() xmlToolsString, err := xmlTools.XMLString()
if err != nil { if err != nil {
panic("Could not serialize []api.Tool to XMLTools") panic("Could not serialize []model.Tool to XMLTools")
} }
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
} }

View File

@ -11,9 +11,11 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func convertTools(tools []api.ToolSpec) []Tool { func convertTools(tools []model.Tool) []Tool {
geminiTools := make([]Tool, len(tools)) geminiTools := make([]Tool, len(tools))
for i, tool := range tools { for i, tool := range tools {
params := make(map[string]ToolParameter) params := make(map[string]ToolParameter)
@ -48,7 +50,7 @@ func convertTools(tools []api.ToolSpec) []Tool {
return geminiTools return geminiTools
} }
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart { func convertToolCallToGemini(toolCalls []model.ToolCall) []ContentPart {
converted := make([]ContentPart, len(toolCalls)) converted := make([]ContentPart, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
args := make(map[string]string) args := make(map[string]string)
@ -63,8 +65,8 @@ func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
return converted return converted
} }
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall { func convertToolCallToAPI(functionCalls []FunctionCall) []model.ToolCall {
converted := make([]api.ToolCall, len(functionCalls)) converted := make([]model.ToolCall, len(functionCalls))
for i, call := range functionCalls { for i, call := range functionCalls {
params := make(map[string]interface{}) params := make(map[string]interface{})
for k, v := range call.Args { for k, v := range call.Args {
@ -76,7 +78,7 @@ func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
return converted return converted
} }
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) { func convertToolResultsToGemini(toolResults []model.ToolResult) ([]FunctionResponse, error) {
results := make([]FunctionResponse, len(toolResults)) results := make([]FunctionResponse, len(toolResults))
for i, result := range toolResults { for i, result := range toolResults {
var obj interface{} var obj interface{}
@ -93,14 +95,14 @@ func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionRespons
} }
func createGenerateContentRequest( func createGenerateContentRequest(
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) (*GenerateContentRequest, error) { ) (*GenerateContentRequest, error) {
requestContents := make([]Content, 0, len(messages)) requestContents := make([]Content, 0, len(messages))
startIdx := 0 startIdx := 0
var system string var system string
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem { if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
system = messages[0].Content system = messages[0].Content
startIdx = 1 startIdx = 1
} }
@ -133,9 +135,9 @@ func createGenerateContentRequest(
default: default:
var role string var role string
switch m.Role { switch m.Role {
case api.MessageRoleAssistant: case model.MessageRoleAssistant:
role = "model" role = "model"
case api.MessageRoleUser: case model.MessageRoleUser:
role = "user" role = "user"
} }
@ -181,15 +183,56 @@ func createGenerateContentRequest(
return request, nil return request, nil
} }
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { func handleToolCalls(
req.Header.Set("Content-Type", "application/json") params model.RequestParameters,
content string,
toolCalls []model.ToolCall,
callback api.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
client := &http.Client{} toolCall := model.Message{
resp, err := client.Do(req) Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: toolCalls,
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil { if err != nil {
return nil, err return nil, err
} }
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *Client) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req.WithContext(ctx))
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body) bytes, _ := io.ReadAll(resp.Body)
return resp, fmt.Errorf("%v", string(bytes)) return resp, fmt.Errorf("%v", string(bytes))
@ -200,41 +243,42 @@ func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
func (c *Client) CreateChatCompletion( func (c *Client) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) (*api.Message, error) { callback api.ReplyCallback,
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req, err := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil { if err != nil {
return nil, err return "", err
} }
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
url := fmt.Sprintf( url := fmt.Sprintf(
"%s/v1beta/models/%s:generateContent?key=%s", "%s/v1beta/models/%s:generateContent?key=%s",
c.BaseURL, params.Model, c.APIKey, c.BaseURL, params.Model, c.APIKey,
) )
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp GenerateContentResponse var completionResp GenerateContentResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return nil, err return "", err
} }
choice := completionResp.Candidates[0] choice := completionResp.Candidates[0]
@ -257,50 +301,58 @@ func (c *Client) CreateChatCompletion(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ messages, err := handleToolCalls(
Role: api.MessageRoleToolCall, params, content, convertToolCallToAPI(toolCalls), callback, messages,
Content: content, )
ToolCalls: convertToolCallToAPI(toolCalls), if err != nil {
}, nil return content, err
} }
return &api.Message{ return c.CreateChatCompletion(ctx, params, messages, callback)
Role: api.MessageRoleAssistant, }
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}, nil })
}
return content, nil
} }
func (c *Client) CreateChatCompletionStream( func (c *Client) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (*api.Message, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req, err := createGenerateContentRequest(params, messages) req, err := createGenerateContentRequest(params, messages)
if err != nil { if err != nil {
return nil, err return "", err
} }
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
url := fmt.Sprintf( url := fmt.Sprintf(
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse", "%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
c.BaseURL, params.Model, c.APIKey, c.BaseURL, params.Model, c.APIKey,
) )
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -322,7 +374,7 @@ func (c *Client) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return nil, err return "", err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -335,7 +387,7 @@ func (c *Client) CreateChatCompletionStream(
var resp GenerateContentResponse var resp GenerateContentResponse
err = json.Unmarshal(line, &resp) err = json.Unmarshal(line, &resp)
if err != nil { if err != nil {
return nil, err return "", err
} }
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
@ -357,15 +409,21 @@ func (c *Client) CreateChatCompletionStream(
// If there are function calls, handle them and recurse // If there are function calls, handle them and recurse
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ messages, err := handleToolCalls(
Role: api.MessageRoleToolCall, params, content.String(), convertToolCallToAPI(toolCalls), callback, messages,
Content: content.String(), )
ToolCalls: convertToolCallToAPI(toolCalls), if err != nil {
}, nil return content.String(), err
}
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
return &api.Message{ if callback != nil {
Role: api.MessageRoleAssistant, callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}, nil })
}
return content.String(), nil
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type OllamaClient struct { type OllamaClient struct {
@ -42,8 +43,8 @@ type OllamaResponse struct {
} }
func createOllamaRequest( func createOllamaRequest(
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) OllamaRequest { ) OllamaRequest {
requestMessages := make([]OllamaMessage, 0, len(messages)) requestMessages := make([]OllamaMessage, 0, len(messages))
@ -63,11 +64,11 @@ func createOllamaRequest(
return request return request
} }
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) { func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req.WithContext(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,11 +83,12 @@ func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
func (c *OllamaClient) CreateChatCompletion( func (c *OllamaClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) (*api.Message, error) { callback api.ReplyCallback,
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createOllamaRequest(params, messages) req := createOllamaRequest(params, messages)
@ -94,40 +96,46 @@ func (c *OllamaClient) CreateChatCompletion(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp OllamaResponse var completionResp OllamaResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return nil, err return "", err
} }
return &api.Message{ content := completionResp.Message.Content
Role: api.MessageRoleAssistant, if callback != nil {
Content: completionResp.Message.Content, callback(model.Message{
}, nil Role: model.MessageRoleAssistant,
Content: content,
})
}
return content, nil
} }
func (c *OllamaClient) CreateChatCompletionStream( func (c *OllamaClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (*api.Message, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createOllamaRequest(params, messages) req := createOllamaRequest(params, messages)
@ -135,17 +143,17 @@ func (c *OllamaClient) CreateChatCompletionStream(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -158,7 +166,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return nil, err return "", err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -169,7 +177,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
var streamResp OllamaResponse var streamResp OllamaResponse
err = json.Unmarshal(line, &streamResp) err = json.Unmarshal(line, &streamResp)
if err != nil { if err != nil {
return nil, err return "", err
} }
if len(streamResp.Message.Content) > 0 { if len(streamResp.Message.Content) > 0 {
@ -181,8 +189,12 @@ func (c *OllamaClient) CreateChatCompletionStream(
} }
} }
return &api.Message{ if callback != nil {
Role: api.MessageRoleAssistant, callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}, nil })
}
return content.String(), nil
} }

View File

@ -11,9 +11,11 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
) )
func convertTools(tools []api.ToolSpec) []Tool { func convertTools(tools []model.Tool) []Tool {
openaiTools := make([]Tool, len(tools)) openaiTools := make([]Tool, len(tools))
for i, tool := range tools { for i, tool := range tools {
openaiTools[i].Type = "function" openaiTools[i].Type = "function"
@ -45,7 +47,7 @@ func convertTools(tools []api.ToolSpec) []Tool {
return openaiTools return openaiTools
} }
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall { func convertToolCallToOpenAI(toolCalls []model.ToolCall) []ToolCall {
converted := make([]ToolCall, len(toolCalls)) converted := make([]ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].Type = "function" converted[i].Type = "function"
@ -58,8 +60,8 @@ func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
return converted return converted
} }
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall { func convertToolCallToAPI(toolCalls []ToolCall) []model.ToolCall {
converted := make([]api.ToolCall, len(toolCalls)) converted := make([]model.ToolCall, len(toolCalls))
for i, call := range toolCalls { for i, call := range toolCalls {
converted[i].ID = call.ID converted[i].ID = call.ID
converted[i].Name = call.Function.Name converted[i].Name = call.Function.Name
@ -69,8 +71,8 @@ func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
} }
func createChatCompletionRequest( func createChatCompletionRequest(
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) ChatCompletionRequest { ) ChatCompletionRequest {
requestMessages := make([]ChatCompletionMessage, 0, len(messages)) requestMessages := make([]ChatCompletionMessage, 0, len(messages))
@ -115,15 +117,56 @@ func createChatCompletionRequest(
return request return request
} }
func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) { func handleToolCalls(
params model.RequestParameters,
content string,
toolCalls []ToolCall,
callback api.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{
Role: model.MessageRoleToolCall,
Content: content,
ToolCalls: convertToolCallToAPI(toolCalls),
}
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
if err != nil {
return nil, err
}
toolResult := model.Message{
Role: model.MessageRoleToolResult,
ToolResults: toolResults,
}
if callback != nil {
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
}
func (c *OpenAIClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey) req.Header.Set("Authorization", "Bearer "+c.APIKey)
client := &http.Client{} client := &http.Client{}
resp, err := client.Do(req) resp, err := client.Do(req.WithContext(ctx))
if err != nil {
return nil, err
}
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
bytes, _ := io.ReadAll(resp.Body) bytes, _ := io.ReadAll(resp.Body)
@ -135,34 +178,35 @@ func (c *OpenAIClient) sendRequest(req *http.Request) (*http.Response, error) {
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
) (*api.Message, error) { callback api.ReplyCallback,
) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createChatCompletionRequest(params, messages) req := createChatCompletionRequest(params, messages)
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
var completionResp ChatCompletionResponse var completionResp ChatCompletionResponse
err = json.NewDecoder(resp.Body).Decode(&completionResp) err = json.NewDecoder(resp.Body).Decode(&completionResp)
if err != nil { if err != nil {
return nil, err return "", err
} }
choice := completionResp.Choices[0] choice := completionResp.Choices[0]
@ -177,27 +221,34 @@ func (c *OpenAIClient) CreateChatCompletion(
toolCalls := choice.Message.ToolCalls toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ messages, err := handleToolCalls(params, content, toolCalls, callback, messages)
Role: api.MessageRoleToolCall, if err != nil {
Content: content, return content, err
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ return c.CreateChatCompletion(ctx, params, messages, callback)
Role: api.MessageRoleAssistant, }
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content, Content: content,
}, nil })
}
// Return the user-facing message.
return content, nil
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params api.RequestParameters, params model.RequestParameters,
messages []api.Message, messages []model.Message,
callback api.ReplyCallback,
output chan<- api.Chunk, output chan<- api.Chunk,
) (*api.Message, error) { ) (string, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, fmt.Errorf("Can't create completion from no messages") return "", fmt.Errorf("Can't create completion from no messages")
} }
req := createChatCompletionRequest(params, messages) req := createChatCompletionRequest(params, messages)
@ -205,17 +256,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
jsonData, err := json.Marshal(req) jsonData, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return "", err
} }
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData)) httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat/completions", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return "", err
} }
resp, err := c.sendRequest(httpReq) resp, err := c.sendRequest(ctx, httpReq)
if err != nil { if err != nil {
return nil, err return "", err
} }
defer resp.Body.Close() defer resp.Body.Close()
@ -234,7 +285,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err == io.EOF { if err == io.EOF {
break break
} }
return nil, err return "", err
} }
line = bytes.TrimSpace(line) line = bytes.TrimSpace(line)
@ -250,7 +301,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
var streamResp ChatCompletionStreamResponse var streamResp ChatCompletionStreamResponse
err = json.Unmarshal(line, &streamResp) err = json.Unmarshal(line, &streamResp)
if err != nil { if err != nil {
return nil, err return "", err
} }
delta := streamResp.Choices[0].Delta delta := streamResp.Choices[0].Delta
@ -258,7 +309,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
// Construct streamed tool_call arguments // Construct streamed tool_call arguments
for _, tc := range delta.ToolCalls { for _, tc := range delta.ToolCalls {
if tc.Index == nil { if tc.Index == nil {
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.") return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
} }
if len(toolCalls) <= *tc.Index { if len(toolCalls) <= *tc.Index {
toolCalls = append(toolCalls, tc) toolCalls = append(toolCalls, tc)
@ -277,15 +328,21 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
return &api.Message{ messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages)
Role: api.MessageRoleToolCall, if err != nil {
Content: content.String(), return content.String(), err
ToolCalls: convertToolCallToAPI(toolCalls),
}, nil
} }
return &api.Message{ // Recurse into CreateChatCompletionStream with the tool call replies
Role: api.MessageRoleAssistant, return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} else {
if callback != nil {
callback(model.Message{
Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}, nil })
}
}
return content.String(), nil
} }

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -36,7 +36,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
lastMessage := &messages[len(messages)-1] lastMessage := &messages[len(messages)-1]
if lastMessage.Role != api.MessageRoleAssistant { if lastMessage.Role != model.MessageRoleAssistant {
return fmt.Errorf("the last message in the conversation is not an assistant message") return fmt.Errorf("the last message in the conversation is not an assistant message")
} }
@ -50,7 +50,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -53,10 +53,10 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
role, _ := cmd.Flags().GetString("role") role, _ := cmd.Flags().GetString("role")
if role != "" { if role != "" {
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) { if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.") return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
} }
toEdit.Role = api.MessageRole(role) toEdit.Role = model.MessageRole(role)
} }
// Update the message in-place // Update the message in-place

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -20,19 +20,19 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []api.Message var messages []model.Message
// TODO: probably just make this part of the conversation // TODO: probably just make this part of the conversation
system := ctx.GetSystemPrompt() system := ctx.GetSystemPrompt()
if system != "" { if system != "" {
messages = append(messages, api.Message{ messages = append(messages, model.Message{
Role: api.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: system, Content: system,
}) })
} }
messages = append(messages, api.Message{ messages = append(messages, model.Message{
Role: api.MessageRoleUser, Role: model.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -20,19 +20,19 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []api.Message var messages []model.Message
// TODO: stop supplying system prompt as a message // TODO: stop supplying system prompt as a message
system := ctx.GetSystemPrompt() system := ctx.GetSystemPrompt()
if system != "" { if system != "" {
messages = append(messages, api.Message{ messages = append(messages, model.Message{
Role: api.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: system, Content: system,
}) })
} }
messages = append(messages, api.Message{ messages = append(messages, model.Message{
Role: api.MessageRoleUser, Role: model.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -23,7 +23,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
return nil return nil
}, },
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
var toRemove []*api.Conversation var toRemove []*model.Conversation
for _, shortName := range args { for _, shortName := range args {
conversation := cmdutil.LookupConversation(ctx, shortName) conversation := cmdutil.LookupConversation(ctx, shortName)
toRemove = append(toRemove, conversation) toRemove = append(toRemove, conversation)

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -30,8 +30,8 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No reply was provided.") return fmt.Errorf("No reply was provided.")
} }
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{ cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
Role: api.MessageRoleUser, Role: model.MessageRoleUser,
Content: reply, Content: reply,
}) })
return nil return nil

View File

@ -3,9 +3,9 @@ package cmd
import ( import (
"fmt" "fmt"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -43,11 +43,11 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
retryFromIdx := len(messages) - 1 - offset retryFromIdx := len(messages) - 1 - offset
// decrease retryFromIdx until we hit a user message // decrease retryFromIdx until we hit a user message
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser { for retryFromIdx >= 0 && messages[retryFromIdx].Role != model.MessageRoleUser {
retryFromIdx-- retryFromIdx--
} }
if messages[retryFromIdx].Role != api.MessageRoleUser { if messages[retryFromIdx].Role != model.MessageRoleUser {
return fmt.Errorf("No user messages to retry") return fmt.Errorf("No user messages to retry")
} }

View File

@ -10,36 +10,36 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
// Prompt prompts the configured the configured model and streams the response // Prompt prompts the configured the configured model and streams the response
// to stdout. Returns all model reply messages. // to stdout. Returns all model reply messages.
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan api.Chunk) // receives the reponse from LLM
defer close(content)
// render all content received over the channel
go ShowDelayedContent(content)
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model) m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return nil, err return "", err
} }
requestParams := api.RequestParameters{ requestParams := model.RequestParameters{
Model: m, Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens, MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature, Temperature: *ctx.Config.Defaults.Temperature,
ToolBag: ctx.EnabledTools, ToolBag: ctx.EnabledTools,
} }
content := make(chan api.Chunk) response, err := provider.CreateChatCompletionStream(
defer close(content) context.Background(), requestParams, messages, callback, content,
// render the content received over the channel
go ShowDelayedContent(content)
reply, err := provider.CreateChatCompletionStream(
context.Background(), requestParams, messages, content,
) )
if response != "" {
if reply.Content != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
fmt.Println() fmt.Println()
@ -48,12 +48,12 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
err = nil err = nil
} }
} }
return reply, err return response, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
// short name or exits the program // short name or exits the program
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation { func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
lmcli.Fatal("Could not lookup conversation: %v\n", err) lmcli.Fatal("Could not lookup conversation: %v\n", err)
@ -64,7 +64,7 @@ func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation
return c return c
} }
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) { func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
c, err := ctx.Store.ConversationByShortName(shortName) c, err := ctx.Store.ConversationByShortName(shortName)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not lookup conversation: %v", err) return nil, fmt.Errorf("Could not lookup conversation: %v", err)
@ -75,7 +75,7 @@ func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversatio
return c, nil return c, nil
} }
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) { func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot) messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
if err != nil { if err != nil {
lmcli.Fatal("Could not load messages: %v\n", err) lmcli.Fatal("Could not load messages: %v\n", err)
@ -85,7 +85,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bo
// handleConversationReply handles sending messages to an existing // handleConversationReply handles sending messages to an existing
// conversation, optionally persisting both the sent replies and responses. // conversation, optionally persisting both the sent replies and responses.
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) { func HandleReply(ctx *lmcli.Context, to *model.Message, persist bool, messages ...model.Message) {
if to == nil { if to == nil {
lmcli.Fatal("Can't prompt from an empty message.") lmcli.Fatal("Can't prompt from an empty message.")
} }
@ -97,7 +97,7 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...
RenderConversation(ctx, append(existing, messages...), true) RenderConversation(ctx, append(existing, messages...), true)
var savedReplies []api.Message var savedReplies []model.Message
if persist && len(messages) > 0 { if persist && len(messages) > 0 {
savedReplies, err = ctx.Store.Reply(to, messages...) savedReplies, err = ctx.Store.Reply(to, messages...)
if err != nil { if err != nil {
@ -106,15 +106,15 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...
} }
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
var lastSavedMessage *api.Message var lastSavedMessage *model.Message
lastSavedMessage = to lastSavedMessage = to
if len(savedReplies) > 0 { if len(savedReplies) > 0 {
lastSavedMessage = &savedReplies[len(savedReplies)-1] lastSavedMessage = &savedReplies[len(savedReplies)-1]
} }
replyCallback := func(reply api.Message) { replyCallback := func(reply model.Message) {
if !persist { if !persist {
return return
} }
@ -131,16 +131,16 @@ func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...
} }
} }
func FormatForExternalPrompt(messages []api.Message, system bool) string { func FormatForExternalPrompt(messages []model.Message, system bool) string {
sb := strings.Builder{} sb := strings.Builder{}
for _, message := range messages { for _, message := range messages {
if message.Content == "" { if message.Content == "" {
continue continue
} }
switch message.Role { switch message.Role {
case api.MessageRoleAssistant, api.MessageRoleToolCall: case model.MessageRoleAssistant, model.MessageRoleToolCall:
sb.WriteString("Assistant:\n\n") sb.WriteString("Assistant:\n\n")
case api.MessageRoleUser: case model.MessageRoleUser:
sb.WriteString("User:\n\n") sb.WriteString("User:\n\n")
default: default:
continue continue
@ -150,7 +150,7 @@ func FormatForExternalPrompt(messages []api.Message, system bool) string {
return sb.String() return sb.String()
} }
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) { func GenerateTitle(ctx *lmcli.Context, messages []model.Message) (string, error) {
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below. const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
Example conversation: Example conversation:
@ -177,32 +177,28 @@ Example response:
return "", err return "", err
} }
generateRequest := []api.Message{ generateRequest := []model.Message{
{ {
Role: api.MessageRoleSystem, Role: model.MessageRoleSystem,
Content: systemPrompt, Content: systemPrompt,
}, },
{ {
Role: api.MessageRoleUser, Role: model.MessageRoleUser,
Content: string(conversation), Content: string(conversation),
}, },
} }
m, provider, err := ctx.GetModelProvider( m, provider, err := ctx.GetModelProvider(*ctx.Config.Conversations.TitleGenerationModel)
*ctx.Config.Conversations.TitleGenerationModel,
)
if err != nil { if err != nil {
return "", err return "", err
} }
requestParams := api.RequestParameters{ requestParams := model.RequestParameters{
Model: m, Model: m,
MaxTokens: 25, MaxTokens: 25,
} }
response, err := provider.CreateChatCompletion( response, err := provider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
context.Background(), requestParams, generateRequest,
)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -211,7 +207,7 @@ Example response:
var jsonResponse struct { var jsonResponse struct {
Title string `json:"title"` Title string `json:"title"`
} }
err = json.Unmarshal([]byte(response.Content), &jsonResponse) err = json.Unmarshal([]byte(response), &jsonResponse)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -276,7 +272,7 @@ func ShowDelayedContent(content <-chan api.Chunk) {
// RenderConversation renders the given messages to TTY, with optional space // RenderConversation renders the given messages to TTY, with optional space
// for a subsequent message. spaceForResponse controls how many '\n' characters // for a subsequent message. spaceForResponse controls how many '\n' characters
// are printed immediately after the final message (1 if false, 2 if true) // are printed immediately after the final message (1 if false, 2 if true)
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) { func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
l := len(messages) l := len(messages)
for i, message := range messages { for i, message := range messages {
RenderMessage(ctx, &message) RenderMessage(ctx, &message)
@ -287,7 +283,7 @@ func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResp
} }
} }
func RenderMessage(ctx *lmcli.Context, m *api.Message) { func RenderMessage(ctx *lmcli.Context, m *model.Message) {
var messageAge string var messageAge string
if m.CreatedAt.IsZero() { if m.CreatedAt.IsZero() {
messageAge = "now" messageAge = "now"
@ -299,11 +295,11 @@ func RenderMessage(ctx *lmcli.Context, m *api.Message) {
headerStyle := lipgloss.NewStyle().Bold(true) headerStyle := lipgloss.NewStyle().Bold(true)
switch m.Role { switch m.Role {
case api.MessageRoleSystem: case model.MessageRoleSystem:
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
case api.MessageRoleUser: case model.MessageRoleUser:
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
case api.MessageRoleAssistant: case model.MessageRoleAssistant:
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
} }

View File

@ -6,12 +6,13 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/agent"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@ -23,7 +24,7 @@ type Context struct {
Store ConversationStore Store ConversationStore
Chroma *tty.ChromaHighlighter Chroma *tty.ChromaHighlighter
EnabledTools []api.ToolSpec EnabledTools []model.Tool
SystemPromptFile string SystemPromptFile string
} }
@ -49,9 +50,9 @@ func NewContext() (*Context, error) {
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style) chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
var enabledTools []api.ToolSpec var enabledTools []model.Tool
for _, toolName := range config.Tools.EnabledTools { for _, toolName := range config.Tools.EnabledTools {
tool, ok := agent.AvailableTools[toolName] tool, ok := tools.AvailableTools[toolName]
if ok { if ok {
enabledTools = append(enabledTools, tool) enabledTools = append(enabledTools, tool)
} }
@ -78,7 +79,7 @@ func (c *Context) GetModels() (models []string) {
return return
} }
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")
var provider string var provider string

View File

@ -1,6 +1,7 @@
package api package model
import ( import (
"database/sql"
"time" "time"
) )
@ -31,6 +32,24 @@ type Message struct {
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
type Conversation struct {
ID uint `gorm:"primaryKey"`
ShortName sql.NullString
Title string
SelectedRootID *uint
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
}
type RequestParameters struct {
Model string
MaxTokens int
Temperature float32
TopP float32
ToolBag []Tool
}
func (m *MessageRole) IsAssistant() bool { func (m *MessageRole) IsAssistant() bool {
switch *m { switch *m {
case MessageRoleAssistant, MessageRoleToolCall: case MessageRoleAssistant, MessageRoleToolCall:

View File

@ -1,4 +1,4 @@
package api package model
import ( import (
"database/sql/driver" "database/sql/driver"
@ -6,11 +6,11 @@ import (
"fmt" "fmt"
) )
type ToolSpec struct { type Tool struct {
Name string Name string
Description string Description string
Parameters []ToolParameter Parameters []ToolParameter
Impl func(*ToolSpec, map[string]interface{}) (string, error) Impl func(*Tool, map[string]interface{}) (string, error)
} }
type ToolParameter struct { type ToolParameter struct {
@ -27,12 +27,6 @@ type ToolCall struct {
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"` Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
} }
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolCalls []ToolCall type ToolCalls []ToolCall
func (tc *ToolCalls) Scan(value any) (err error) { func (tc *ToolCalls) Scan(value any) (err error) {
@ -56,6 +50,12 @@ func (tc ToolCalls) Value() (driver.Value, error) {
return string(jsonBytes), nil return string(jsonBytes), nil
} }
type ToolResult struct {
ToolCallID string `json:"toolCallID" yaml:"-"`
ToolName string `json:"toolName,omitempty" yaml:"tool"`
Result string `json:"result,omitempty" yaml:"result"`
}
type ToolResults []ToolResult type ToolResults []ToolResult
func (tr *ToolResults) Scan(value any) (err error) { func (tr *ToolResults) Scan(value any) (err error) {

View File

@ -8,32 +8,32 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
sqids "github.com/sqids/sqids-go" sqids "github.com/sqids/sqids-go"
"gorm.io/gorm" "gorm.io/gorm"
) )
type ConversationStore interface { type ConversationStore interface {
ConversationByShortName(shortName string) (*api.Conversation, error) ConversationByShortName(shortName string) (*model.Conversation, error)
ConversationShortNameCompletions(search string) []string ConversationShortNameCompletions(search string) []string
RootMessages(conversationID uint) ([]api.Message, error) RootMessages(conversationID uint) ([]model.Message, error)
LatestConversationMessages() ([]api.Message, error) LatestConversationMessages() ([]model.Message, error)
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error)
UpdateConversation(conversation *api.Conversation) error UpdateConversation(conversation *model.Conversation) error
DeleteConversation(conversation *api.Conversation) error DeleteConversation(conversation *model.Conversation) error
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error)
MessageByID(messageID uint) (*api.Message, error) MessageByID(messageID uint) (*model.Message, error)
MessageReplies(messageID uint) ([]api.Message, error) MessageReplies(messageID uint) ([]model.Message, error)
UpdateMessage(message *api.Message) error UpdateMessage(message *model.Message) error
DeleteMessage(message *api.Message, prune bool) error DeleteMessage(message *model.Message, prune bool) error
CloneBranch(toClone api.Message) (*api.Message, uint, error) CloneBranch(toClone model.Message) (*model.Message, uint, error)
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error)
PathToRoot(message *api.Message) ([]api.Message, error) PathToRoot(message *model.Message) ([]model.Message, error)
PathToLeaf(message *api.Message) ([]api.Message, error) PathToLeaf(message *model.Message) ([]model.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -43,8 +43,8 @@ type SQLStore struct {
func NewSQLStore(db *gorm.DB) (*SQLStore, error) { func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
models := []any{ models := []any{
&api.Conversation{}, &model.Conversation{},
&api.Message{}, &model.Message{},
} }
for _, x := range models { for _, x := range models {
@ -58,9 +58,9 @@ func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
return &SQLStore{db, _sqids}, nil return &SQLStore{db, _sqids}, nil
} }
func (s *SQLStore) createConversation() (*api.Conversation, error) { func (s *SQLStore) createConversation() (*model.Conversation, error) {
// Create the new conversation // Create the new conversation
c := &api.Conversation{} c := &model.Conversation{}
err := s.db.Save(c).Error err := s.db.Save(c).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -75,28 +75,28 @@ func (s *SQLStore) createConversation() (*api.Conversation, error) {
return c, nil return c, nil
} }
func (s *SQLStore) UpdateConversation(c *api.Conversation) error { func (s *SQLStore) UpdateConversation(c *model.Conversation) error {
if c == nil || c.ID == 0 { if c == nil || c.ID == 0 {
return fmt.Errorf("Conversation is nil or invalid (missing ID)") return fmt.Errorf("Conversation is nil or invalid (missing ID)")
} }
return s.db.Updates(c).Error return s.db.Updates(c).Error
} }
func (s *SQLStore) DeleteConversation(c *api.Conversation) error { func (s *SQLStore) DeleteConversation(c *model.Conversation) error {
// Delete messages first // Delete messages first
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error err := s.db.Where("conversation_id = ?", c.ID).Delete(&model.Message{}).Error
if err != nil { if err != nil {
return err return err
} }
return s.db.Delete(c).Error return s.db.Delete(c).Error
} }
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error { func (s *SQLStore) DeleteMessage(message *model.Message, prune bool) error {
panic("Not yet implemented") panic("Not yet implemented")
//return s.db.Delete(&message).Error //return s.db.Delete(&message).Error
} }
func (s *SQLStore) UpdateMessage(m *api.Message) error { func (s *SQLStore) UpdateMessage(m *model.Message) error {
if m == nil || m.ID == 0 { if m == nil || m.ID == 0 {
return fmt.Errorf("Message is nil or invalid (missing ID)") return fmt.Errorf("Message is nil or invalid (missing ID)")
} }
@ -104,7 +104,7 @@ func (s *SQLStore) UpdateMessage(m *api.Message) error {
} }
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string { func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
var conversations []api.Conversation var conversations []model.Conversation
// ignore error for completions // ignore error for completions
s.db.Find(&conversations) s.db.Find(&conversations)
completions := make([]string, 0, len(conversations)) completions := make([]string, 0, len(conversations))
@ -116,17 +116,17 @@ func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
return completions return completions
} }
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) { func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
if shortName == "" { if shortName == "" {
return nil, errors.New("shortName is empty") return nil, errors.New("shortName is empty")
} }
var conversation api.Conversation var conversation model.Conversation
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
return &conversation, err return &conversation, err
} }
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) { func (s *SQLStore) RootMessages(conversationID uint) ([]model.Message, error) {
var rootMessages []api.Message var rootMessages []model.Message
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
if err != nil { if err != nil {
return nil, err return nil, err
@ -134,20 +134,20 @@ func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
return rootMessages, nil return rootMessages, nil
} }
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) { func (s *SQLStore) MessageByID(messageID uint) (*model.Message, error) {
var message api.Message var message model.Message
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
return &message, err return &message, err
} }
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) { func (s *SQLStore) MessageReplies(messageID uint) ([]model.Message, error) {
var replies []api.Message var replies []model.Message
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
return replies, err return replies, err
} }
// StartConversation starts a new conversation with the provided messages // StartConversation starts a new conversation with the provided messages
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) { func (s *SQLStore) StartConversation(messages ...model.Message) (*model.Conversation, []model.Message, error) {
if len(messages) == 0 { if len(messages) == 0 {
return nil, nil, fmt.Errorf("Must provide at least 1 message") return nil, nil, fmt.Errorf("Must provide at least 1 message")
} }
@ -178,13 +178,13 @@ func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
messages = append([]api.Message{messages[0]}, newMessages...) messages = append([]model.Message{messages[0]}, newMessages...)
} }
return conversation, messages, nil return conversation, messages, nil
} }
// CloneConversation clones the given conversation and all of its root meesages // CloneConversation clones the given conversation and all of its root meesages
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) { func (s *SQLStore) CloneConversation(toClone model.Conversation) (*model.Conversation, uint, error) {
rootMessages, err := s.RootMessages(toClone.ID) rootMessages, err := s.RootMessages(toClone.ID)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
@ -226,8 +226,8 @@ func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversatio
} }
// Reply to a message with a series of messages (each following the next) // Reply to a message with a series of messages (each following the next)
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) { func (s *SQLStore) Reply(to *model.Message, messages ...model.Message) ([]model.Message, error) {
var savedMessages []api.Message var savedMessages []model.Message
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
currentParent := to currentParent := to
@ -262,7 +262,7 @@ func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Messag
// CloneBranch returns a deep clone of the given message and its replies, returning // CloneBranch returns a deep clone of the given message and its replies, returning
// a new message object. The new message will be attached to the same parent as // a new message object. The new message will be attached to the same parent as
// the messageToClone // the messageToClone
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) { func (s *SQLStore) CloneBranch(messageToClone model.Message) (*model.Message, uint, error) {
newMessage := messageToClone newMessage := messageToClone
newMessage.ID = 0 newMessage.ID = 0
newMessage.Replies = nil newMessage.Replies = nil
@ -304,19 +304,19 @@ func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint,
return &newMessage, replyCount, nil return &newMessage, replyCount, nil
} }
func fetchMessages(db *gorm.DB) ([]api.Message, error) { func fetchMessages(db *gorm.DB) ([]model.Message, error) {
var messages []api.Message var messages []model.Message
if err := db.Preload("Conversation").Find(&messages).Error; err != nil { if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
return nil, fmt.Errorf("Could not fetch messages: %v", err) return nil, fmt.Errorf("Could not fetch messages: %v", err)
} }
messageMap := make(map[uint]api.Message) messageMap := make(map[uint]model.Message)
for i, message := range messages { for i, message := range messages {
messageMap[messages[i].ID] = message messageMap[messages[i].ID] = message
} }
// Create a map to store replies by their parent ID // Create a map to store replies by their parent ID
repliesMap := make(map[uint][]api.Message) repliesMap := make(map[uint][]model.Message)
for i, message := range messages { for i, message := range messages {
if messages[i].ParentID != nil { if messages[i].ParentID != nil {
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message) repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
@ -326,7 +326,7 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
// Assign replies, parent, and selected reply to each message // Assign replies, parent, and selected reply to each message
for i := range messages { for i := range messages {
if replies, exists := repliesMap[messages[i].ID]; exists { if replies, exists := repliesMap[messages[i].ID]; exists {
messages[i].Replies = make([]api.Message, len(replies)) messages[i].Replies = make([]model.Message, len(replies))
for j, m := range replies { for j, m := range replies {
messages[i].Replies[j] = m messages[i].Replies[j] = m
} }
@ -345,21 +345,21 @@ func fetchMessages(db *gorm.DB) ([]api.Message, error) {
return messages, nil return messages, nil
} }
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) { func (s *SQLStore) buildPath(message *model.Message, getNext func(*model.Message) *uint) ([]model.Message, error) {
var messages []api.Message var messages []model.Message
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID)) messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Create a map to store messages by their ID // Create a map to store messages by their ID
messageMap := make(map[uint]*api.Message) messageMap := make(map[uint]*model.Message)
for i := range messages { for i := range messages {
messageMap[messages[i].ID] = &messages[i] messageMap[messages[i].ID] = &messages[i]
} }
// Build the path // Build the path
var path []api.Message var path []model.Message
nextID := &message.ID nextID := &message.ID
for { for {
@ -382,12 +382,12 @@ func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *u
// PathToRoot traverses the provided message's Parent until reaching the tree // PathToRoot traverses the provided message's Parent until reaching the tree
// root and returns a slice of all messages traversed in chronological order // root and returns a slice of all messages traversed in chronological order
// (starting with the root and ending with the message provided) // (starting with the root and ending with the message provided)
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) { func (s *SQLStore) PathToRoot(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
path, err := s.buildPath(message, func(m *api.Message) *uint { path, err := s.buildPath(message, func(m *model.Message) *uint {
return m.ParentID return m.ParentID
}) })
if err != nil { if err != nil {
@ -401,24 +401,24 @@ func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
// PathToLeaf traverses the provided message's SelectedReply until reaching a // PathToLeaf traverses the provided message's SelectedReply until reaching a
// tree leaf and returns a slice of all messages traversed in chronological // tree leaf and returns a slice of all messages traversed in chronological
// order (starting with the message provided and ending with the leaf) // order (starting with the message provided and ending with the leaf)
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) { func (s *SQLStore) PathToLeaf(message *model.Message) ([]model.Message, error) {
if message == nil || message.ID <= 0 { if message == nil || message.ID <= 0 {
return nil, fmt.Errorf("Message is nil or has invalid ID") return nil, fmt.Errorf("Message is nil or has invalid ID")
} }
return s.buildPath(message, func(m *api.Message) *uint { return s.buildPath(message, func(m *model.Message) *uint {
return m.SelectedReplyID return m.SelectedReplyID
}) })
} }
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) { func (s *SQLStore) LatestConversationMessages() ([]model.Message, error) {
var latestMessages []api.Message var latestMessages []model.Message
subQuery := s.db.Model(&api.Message{}). subQuery := s.db.Model(&model.Message{}).
Select("MAX(created_at) as max_created_at, conversation_id"). Select("MAX(created_at) as max_created_at, conversation_id").
Group("conversation_id") Group("conversation_id")
err := s.db.Model(&api.Message{}). err := s.db.Model(&model.Message{}).
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery). Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
Group("messages.conversation_id"). Group("messages.conversation_id").
Order("created_at DESC"). Order("created_at DESC").

View File

@ -1,4 +1,4 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
@ -7,8 +7,8 @@ import (
"strconv" "strconv"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents. const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
@ -27,10 +27,10 @@ Example result:
} }
` `
var DirTreeTool = api.ToolSpec{ var DirTreeTool = model.Tool{
Name: "dir_tree", Name: "dir_tree",
Description: TREE_DESCRIPTION, Description: TREE_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "relative_path", Name: "relative_path",
Type: "string", Type: "string",
@ -42,7 +42,7 @@ var DirTreeTool = api.ToolSpec{
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.", Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
if tmp, ok := args["relative_path"]; ok { if tmp, ok := args["relative_path"]; ok {
relativeDir, ok = tmp.(string) relativeDir, ok = tmp.(string)
@ -76,25 +76,25 @@ var DirTreeTool = api.ToolSpec{
}, },
} }
func tree(path string, depth int) api.CallResult { func tree(path string, depth int) model.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
var treeOutput strings.Builder var treeOutput strings.Builder
treeOutput.WriteString(path + "\n") treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", depth) err := buildTree(&treeOutput, path, "", depth)
if err != nil { if err != nil {
return api.CallResult{ return model.CallResult{
Message: err.Error(), Message: err.Error(),
} }
} }
return api.CallResult{Result: treeOutput.String()} return model.CallResult{Result: treeOutput.String()}
} }
func buildTree(output *strings.Builder, path string, prefix string, depth int) error { func buildTree(output *strings.Builder, path string, prefix string, depth int) error {

View File

@ -1,22 +1,22 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path. const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
Make sure your inserts match the flow and indentation of surrounding content.` Make sure your inserts match the flow and indentation of surrounding content.`
var FileInsertLinesTool = api.ToolSpec{ var FileInsertLinesTool = model.Tool{
Name: "file_insert_lines", Name: "file_insert_lines",
Description: FILE_INSERT_LINES_DESCRIPTION, Description: FILE_INSERT_LINES_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -36,7 +36,7 @@ var FileInsertLinesTool = api.ToolSpec{
Required: true, Required: true,
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.") return "", fmt.Errorf("path parameter to write_file was not included.")
@ -72,27 +72,27 @@ var FileInsertLinesTool = api.ToolSpec{
}, },
} }
func fileInsertLines(path string, position int, content string) api.CallResult { func fileInsertLines(path string, position int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
// Read the existing file's content // Read the existing file's content
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
_, err = os.Create(path) _, err = os.Create(path)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
} }
data = []byte{} data = []byte{}
} }
if position < 1 { if position < 1 {
return api.CallResult{Message: "start_line cannot be less than 1"} return model.CallResult{Message: "start_line cannot be less than 1"}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) api.CallResult {
// Join the lines and write back to the file // Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644) err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return api.CallResult{Result: newContent} return model.CallResult{Result: newContent}
} }

View File

@ -1,12 +1,12 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path. const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
@ -15,10 +15,10 @@ Useful for re-writing snippets/blocks of code or entire functions.
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.` Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
var FileReplaceLinesTool = api.ToolSpec{ var FileReplaceLinesTool = model.Tool{
Name: "file_replace_lines", Name: "file_replace_lines",
Description: FILE_REPLACE_LINES_DESCRIPTION, Description: FILE_REPLACE_LINES_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -42,7 +42,7 @@ var FileReplaceLinesTool = api.ToolSpec{
Description: `Content to replace specified range. Omit to remove the specified range.`, Description: `Content to replace specified range. Omit to remove the specified range.`,
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("path parameter to write_file was not included.") return "", fmt.Errorf("path parameter to write_file was not included.")
@ -87,27 +87,27 @@ var FileReplaceLinesTool = api.ToolSpec{
}, },
} }
func fileReplaceLines(path string, startLine int, endLine int, content string) api.CallResult { func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
// Read the existing file's content // Read the existing file's content
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
_, err = os.Create(path) _, err = os.Create(path)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
} }
data = []byte{} data = []byte{}
} }
if startLine < 1 { if startLine < 1 {
return api.CallResult{Message: "start_line cannot be less than 1"} return model.CallResult{Message: "start_line cannot be less than 1"}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -126,8 +126,8 @@ func fileReplaceLines(path string, startLine int, endLine int, content string) a
// Join the lines and write back to the file // Join the lines and write back to the file
err = os.WriteFile(path, []byte(newContent), 0644) err = os.WriteFile(path, []byte(newContent), 0644)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return api.CallResult{Result: newContent} return model.CallResult{Result: newContent}
} }

View File

@ -1,4 +1,4 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
@ -6,8 +6,8 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory). const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
@ -25,17 +25,17 @@ Example result:
For files, size represents the size of the file, in bytes. For files, size represents the size of the file, in bytes.
For directories, size represents the number of entries in that directory.` For directories, size represents the number of entries in that directory.`
var ReadDirTool = api.ToolSpec{ var ReadDirTool = model.Tool{
Name: "read_dir", Name: "read_dir",
Description: READ_DIR_DESCRIPTION, Description: READ_DIR_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "relative_dir", Name: "relative_dir",
Type: "string", Type: "string",
Description: "If set, read the contents of a directory relative to the current one.", Description: "If set, read the contents of a directory relative to the current one.",
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string var relativeDir string
tmp, ok := args["relative_dir"] tmp, ok := args["relative_dir"]
if ok { if ok {
@ -53,18 +53,18 @@ var ReadDirTool = api.ToolSpec{
}, },
} }
func readDir(path string) api.CallResult { func readDir(path string) model.CallResult {
if path == "" { if path == "" {
path = "." path = "."
} }
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
files, err := os.ReadDir(path) files, err := os.ReadDir(path)
if err != nil { if err != nil {
return api.CallResult{ return model.CallResult{
Message: err.Error(), Message: err.Error(),
} }
} }
@ -96,5 +96,5 @@ func readDir(path string) api.CallResult {
}) })
} }
return api.CallResult{Result: dirContents} return model.CallResult{Result: dirContents}
} }

View File

@ -1,12 +1,12 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory. const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
@ -21,10 +21,10 @@ Example result:
"result": "1\tthe contents\n2\tof the file\n" "result": "1\tthe contents\n2\tof the file\n"
}` }`
var ReadFileTool = api.ToolSpec{ var ReadFileTool = model.Tool{
Name: "read_file", Name: "read_file",
Description: READ_FILE_DESCRIPTION, Description: READ_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -33,7 +33,7 @@ var ReadFileTool = api.ToolSpec{
}, },
}, },
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to read_file was not included.") return "", fmt.Errorf("Path parameter to read_file was not included.")
@ -51,14 +51,14 @@ var ReadFileTool = api.ToolSpec{
}, },
} }
func readFile(path string) api.CallResult { func readFile(path string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
} }
lines := strings.Split(string(data), "\n") lines := strings.Split(string(data), "\n")
@ -67,7 +67,7 @@ func readFile(path string) api.CallResult {
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line)) content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
} }
return api.CallResult{ return model.CallResult{
Result: content.String(), Result: content.String(),
} }
} }

48
pkg/lmcli/tools/tools.go Normal file
View File

@ -0,0 +1,48 @@
package tools
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
var AvailableTools map[string]model.Tool = map[string]model.Tool{
"dir_tree": DirTreeTool,
"read_dir": ReadDirTool,
"read_file": ReadFileTool,
"write_file": WriteFileTool,
"file_insert_lines": FileInsertLinesTool,
"file_replace_lines": FileReplaceLinesTool,
}
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
var toolResults []model.ToolResult
for _, toolCall := range toolCalls {
var tool *model.Tool
for _, available := range toolBag {
if available.Name == toolCall.Name {
tool = &available
break
}
}
if tool == nil {
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
}
// Execute the tool
result, err := tool.Impl(tool, toolCall.Parameters)
if err != nil {
// This can happen if the model missed or supplied invalid tool args
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
}
toolResult := model.ToolResult{
ToolCallID: toolCall.ID,
ToolName: toolCall.Name,
Result: result,
}
toolResults = append(toolResults, toolResult)
}
return toolResults, nil
}

View File

@ -1,11 +1,11 @@
package toolbox package tools
import ( import (
"fmt" "fmt"
"os" "os"
toolutil "git.mlow.ca/mlow/lmcli/pkg/agent/toolbox/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/api" toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
) )
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory. const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
@ -15,10 +15,10 @@ Example result:
"message": "success" "message": "success"
}` }`
var WriteFileTool = api.ToolSpec{ var WriteFileTool = model.Tool{
Name: "write_file", Name: "write_file",
Description: WRITE_FILE_DESCRIPTION, Description: WRITE_FILE_DESCRIPTION,
Parameters: []api.ToolParameter{ Parameters: []model.ToolParameter{
{ {
Name: "path", Name: "path",
Type: "string", Type: "string",
@ -32,7 +32,7 @@ var WriteFileTool = api.ToolSpec{
Required: true, Required: true,
}, },
}, },
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) { Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
tmp, ok := args["path"] tmp, ok := args["path"]
if !ok { if !ok {
return "", fmt.Errorf("Path parameter to write_file was not included.") return "", fmt.Errorf("Path parameter to write_file was not included.")
@ -58,14 +58,14 @@ var WriteFileTool = api.ToolSpec{
}, },
} }
func writeFile(path string, content string) api.CallResult { func writeFile(path string, content string) model.CallResult {
ok, reason := toolutil.IsPathWithinCWD(path) ok, reason := toolutil.IsPathWithinCWD(path)
if !ok { if !ok {
return api.CallResult{Message: reason} return model.CallResult{Message: reason}
} }
err := os.WriteFile(path, []byte(content), 0644) err := os.WriteFile(path, []byte(content), 0644)
if err != nil { if err != nil {
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())} return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
} }
return api.CallResult{} return model.CallResult{}
} }

View File

@ -4,6 +4,7 @@ import (
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" "git.mlow.ca/mlow/lmcli/pkg/api"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
"github.com/charmbracelet/bubbles/spinner" "github.com/charmbracelet/bubbles/spinner"
@ -15,39 +16,37 @@ import (
// custom tea.Msg types // custom tea.Msg types
type ( type (
// sent on each chunk received from LLM
msgResponseChunk api.Chunk
// sent when response is finished being received
msgResponseEnd string
// a special case of common.MsgError that stops the response waiting animation
msgResponseError error
// sent on each completed reply
msgResponse models.Message
// sent when a conversation is (re)loaded // sent when a conversation is (re)loaded
msgConversationLoaded struct { msgConversationLoaded struct {
conversation *api.Conversation conversation *models.Conversation
rootMessages []api.Message rootMessages []models.Message
} }
// sent when a new conversation title generated // sent when a new conversation title generated
msgConversationTitleGenerated string msgConversationTitleGenerated string
// sent when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when the conversation has been persisted, triggers a reload of contents // sent when the conversation has been persisted, triggers a reload of contents
msgConversationPersisted struct { msgConversationPersisted struct {
isNew bool isNew bool
conversation *api.Conversation conversation *models.Conversation
messages []api.Message messages []models.Message
} }
// sent when a conversation's messages are laoded
msgMessagesLoaded []api.Message
// a special case of common.MsgError that stops the response waiting animation
msgChatResponseError error
// sent on each chunk received from LLM
msgChatResponseChunk api.Chunk
// sent on each completed reply
msgChatResponse *api.Message
// sent when the response is canceled
msgChatResponseCanceled struct{}
// sent when results from a tool call are returned
msgToolResults []api.ToolResult
// sent when the given message is made the new selected reply of its parent // sent when the given message is made the new selected reply of its parent
msgSelectedReplyCycled *api.Message msgSelectedReplyCycled *models.Message
// sent when the given message is made the new selected root of the current conversation // sent when the given message is made the new selected root of the current conversation
msgSelectedRootCycled *api.Message msgSelectedRootCycled *models.Message
// sent when a message's contents are updated and saved // sent when a message's contents are updated and saved
msgMessageUpdated *api.Message msgMessageUpdated *models.Message
// sent when a message is cloned, with the cloned message // sent when a message is cloned, with the cloned message
msgMessageCloned *api.Message msgMessageCloned *models.Message
) )
type focusState int type focusState int
@ -78,14 +77,14 @@ type Model struct {
// app state // app state
state state // current overall status of the view state state // current overall status of the view
conversation *api.Conversation conversation *models.Conversation
rootMessages []api.Message rootMessages []models.Message
messages []api.Message messages []models.Message
selectedMessage int selectedMessage int
editorTarget editorTarget editorTarget editorTarget
stopSignal chan struct{} stopSignal chan struct{}
replyChan chan api.Message replyChan chan models.Message
chatReplyChunks chan api.Chunk replyChunkChan chan api.Chunk
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
// ui state // ui state
@ -112,12 +111,12 @@ func Chat(shared shared.Shared) Model {
Shared: shared, Shared: shared,
state: idle, state: idle,
conversation: &api.Conversation{}, conversation: &models.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan struct{}), stopSignal: make(chan struct{}),
replyChan: make(chan api.Message), replyChan: make(chan models.Message),
chatReplyChunks: make(chan api.Chunk), replyChunkChan: make(chan api.Chunk),
wrap: true, wrap: true,
selectedMessage: -1, selectedMessage: -1,
@ -145,8 +144,8 @@ func Chat(shared shared.Shared) Model {
system := shared.Ctx.GetSystemPrompt() system := shared.Ctx.GetSystemPrompt()
if system != "" { if system != "" {
m.messages = []api.Message{{ m.messages = []models.Message{{
Role: api.MessageRoleSystem, Role: models.MessageRoleSystem,
Content: system, Content: system,
}} }}
} }
@ -167,5 +166,6 @@ func Chat(shared shared.Shared) Model {
func (m Model) Init() tea.Cmd { func (m Model) Init() tea.Cmd {
return tea.Batch( return tea.Batch(
m.waitForResponseChunk(), m.waitForResponseChunk(),
m.waitForResponse(),
) )
} }

View File

@ -2,18 +2,16 @@ package chat
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/agent"
"git.mlow.ca/mlow/lmcli/pkg/api"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
) )
func (m *Model) setMessage(i int, msg api.Message) { func (m *Model) setMessage(i int, msg models.Message) {
if i >= len(m.messages) { if i >= len(m.messages) {
panic("i out of range") panic("i out of range")
} }
@ -21,7 +19,7 @@ func (m *Model) setMessage(i int, msg api.Message) {
m.messageCache[i] = m.renderMessage(i) m.messageCache[i] = m.renderMessage(i)
} }
func (m *Model) addMessage(msg api.Message) { func (m *Model) addMessage(msg models.Message) {
m.messages = append(m.messages, msg) m.messages = append(m.messages, msg)
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1)) m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
} }
@ -90,7 +88,7 @@ func (m *Model) generateConversationTitle() tea.Cmd {
} }
} }
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd { func (m *Model) updateConversationTitle(conversation *models.Conversation) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateConversation(conversation) err := m.Shared.Ctx.Store.UpdateConversation(conversation)
if err != nil { if err != nil {
@ -103,7 +101,7 @@ func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd
// Clones the given message (and its descendents). If selected is true, updates // Clones the given message (and its descendents). If selected is true, updates
// either its parent's SelectedReply or its conversation's SelectedRoot to // either its parent's SelectedReply or its conversation's SelectedRoot to
// point to the new clone // point to the new clone
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd { func (m *Model) cloneMessage(message models.Message, selected bool) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
msg, _, err := m.Ctx.Store.CloneBranch(message) msg, _, err := m.Ctx.Store.CloneBranch(message)
if err != nil { if err != nil {
@ -125,7 +123,7 @@ func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
} }
} }
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd { func (m *Model) updateMessageContent(message *models.Message) tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
err := m.Shared.Ctx.Store.UpdateMessage(message) err := m.Shared.Ctx.Store.UpdateMessage(message)
if err != nil { if err != nil {
@ -135,7 +133,7 @@ func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
} }
} }
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) { func cycleSelectedMessage(selected *models.Message, choices []models.Message, dir MessageCycleDirection) (*models.Message, error) {
currentIndex := -1 currentIndex := -1
for i, reply := range choices { for i, reply := range choices {
if reply.ID == selected.ID { if reply.ID == selected.ID {
@ -160,7 +158,7 @@ func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir Mess
return &choices[next], nil return &choices[next], nil
} }
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd { func (m *Model) cycleSelectedRoot(conv *models.Conversation, dir MessageCycleDirection) tea.Cmd {
if len(m.rootMessages) < 2 { if len(m.rootMessages) < 2 {
return nil return nil
} }
@ -180,7 +178,7 @@ func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirect
} }
} }
func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd { func (m *Model) cycleSelectedReply(message *models.Message, dir MessageCycleDirection) tea.Cmd {
if len(message.Replies) < 2 { if len(message.Replies) < 2 {
return nil return nil
} }
@ -220,12 +218,15 @@ func (m *Model) persistConversation() tea.Cmd {
// else, we'll handle updating an existing conversation's messages // else, we'll handle updating an existing conversation's messages
for i := range messages { for i := range messages {
if messages[i].ID > 0 { if messages[i].ID > 0 {
// message has an ID, update it // message has an ID, update its contents
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i]) err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
if err != nil { if err != nil {
return shared.MsgError(err) return shared.MsgError(err)
} }
} else if i > 0 { } else if i > 0 {
if messages[i].Content == "" {
continue
}
// messages is new, so add it as a reply to previous message // messages is new, so add it as a reply to previous message
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i]) saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
if err != nil { if err != nil {
@ -242,23 +243,13 @@ func (m *Model) persistConversation() tea.Cmd {
} }
} }
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
return func() tea.Msg {
results, err := agent.ExecuteToolCalls(toolCalls, m.Ctx.EnabledTools)
if err != nil {
return shared.MsgError(err)
}
return msgToolResults(results)
}
}
func (m *Model) promptLLM() tea.Cmd { func (m *Model) promptLLM() tea.Cmd {
m.state = pendingResponse m.state = pendingResponse
m.replyCursor.Blink = false m.replyCursor.Blink = false
m.tokenCount = 0
m.startTime = time.Now() m.startTime = time.Now()
m.elapsed = 0 m.elapsed = 0
m.tokenCount = 0
return func() tea.Msg { return func() tea.Msg {
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model) model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
@ -266,34 +257,36 @@ func (m *Model) promptLLM() tea.Cmd {
return shared.MsgError(err) return shared.MsgError(err)
} }
requestParams := api.RequestParameters{ requestParams := models.RequestParameters{
Model: model, Model: model,
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens, MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature, Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
ToolBag: m.Shared.Ctx.EnabledTools, ToolBag: m.Shared.Ctx.EnabledTools,
} }
replyHandler := func(msg models.Message) {
m.replyChan <- msg
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
canceled := false
go func() { go func() {
select { select {
case <-m.stopSignal: case <-m.stopSignal:
canceled = true
cancel() cancel()
} }
}() }()
resp, err := provider.CreateChatCompletionStream( resp, err := provider.CreateChatCompletionStream(
ctx, requestParams, m.messages, m.chatReplyChunks, ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
) )
if errors.Is(err, context.Canceled) { if err != nil && !canceled {
return msgChatResponseCanceled(struct{}{}) return msgResponseError(err)
} }
if err != nil { return msgResponseEnd(resp)
return msgChatResponseError(err)
}
return msgChatResponse(resp)
} }
} }

View File

@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
tea "github.com/charmbracelet/bubbletea" tea "github.com/charmbracelet/bubbletea"
@ -150,12 +150,12 @@ func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
return true, nil return true, nil
} }
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser { if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message")) return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
} }
m.addMessage(api.Message{ m.addMessage(models.Message{
Role: api.MessageRoleUser, Role: models.MessageRoleUser,
Content: input, Content: input,
}) })

View File

@ -4,7 +4,7 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/cursor"
@ -21,9 +21,15 @@ func (m *Model) HandleResize(width, height int) {
} }
} }
func (m *Model) waitForResponse() tea.Cmd {
return func() tea.Msg {
return msgResponse(<-m.replyChan)
}
}
func (m *Model) waitForResponseChunk() tea.Cmd { func (m *Model) waitForResponseChunk() tea.Cmd {
return func() tea.Msg { return func() tea.Msg {
return msgChatResponseChunk(<-m.chatReplyChunks) return msgResponseChunk(<-m.replyChunkChan)
} }
} }
@ -42,7 +48,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname { if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
// clear existing messages if we're loading a new conversation // clear existing messages if we're loading a new conversation
m.messages = []api.Message{} m.messages = []models.Message{}
m.selectedMessage = 0 m.selectedMessage = 0
} }
} }
@ -81,7 +87,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()
case msgChatResponseChunk: case msgResponseChunk:
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
if msg.Content == "" { if msg.Content == "" {
@ -93,9 +99,9 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
// append chunk to existing message // append chunk to existing message
m.setMessageContents(last, m.messages[last].Content+msg.Content) m.setMessageContents(last, m.messages[last].Content+msg.Content)
} else { } else {
// use chunk in a new message // use chunk in new message
m.addMessage(api.Message{ m.addMessage(models.Message{
Role: api.MessageRoleAssistant, Role: models.MessageRoleAssistant,
Content: msg.Content, Content: msg.Content,
}) })
} }
@ -107,10 +113,10 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
m.tokenCount += msg.TokenCount m.tokenCount += msg.TokenCount
m.elapsed = time.Now().Sub(m.startTime) m.elapsed = time.Now().Sub(m.startTime)
case msgChatResponse: case msgResponse:
m.state = idle cmds = append(cmds, m.waitForResponse()) // wait for the next response
reply := (*api.Message)(msg) reply := models.Message(msg)
reply.Content = strings.TrimSpace(reply.Content) reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1 last := len(m.messages) - 1
@ -118,18 +124,11 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
panic("Unexpected empty messages handling msgAssistantReply") panic("Unexpected empty messages handling msgAssistantReply")
} }
if m.messages[last].Role.IsAssistant() { if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
// TODO: handle continuations gracefully - some models support them well, others fail horribly. // this was a continuation, so replace the previous message with the completed reply
m.setMessage(last, *reply) m.setMessage(last, reply)
} else { } else {
m.addMessage(*reply) m.addMessage(reply)
}
switch reply.Role {
case api.MessageRoleToolCall:
// TODO: user confirmation before execution
// m.state = waitingForConfirmation
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
} }
if m.persistence { if m.persistence {
@ -141,32 +140,17 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
} }
m.updateContent() m.updateContent()
case msgChatResponseCanceled: case msgResponseEnd:
m.state = idle m.state = idle
m.updateContent()
case msgChatResponseError:
m.state = idle
m.Shared.Err = error(msg)
m.updateContent()
case msgToolResults:
last := len(m.messages) - 1 last := len(m.messages) - 1
if last < 0 { if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply") panic("Unexpected empty messages handling msgResponseEnd")
} }
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
if m.messages[last].Role != api.MessageRoleToolCall { m.updateContent()
panic("Previous message not a tool call, unexpected") case msgResponseError:
} m.state = idle
m.Shared.Err = error(msg)
m.addMessage(api.Message{
Role: api.MessageRoleToolResult,
ToolResults: api.ToolResults(msg),
})
if m.persistence {
cmds = append(cmds, m.persistConversation())
}
m.updateContent() m.updateContent()
case msgConversationTitleGenerated: case msgConversationTitleGenerated:
title := string(msg) title := string(msg)
@ -177,15 +161,13 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
case cursor.BlinkMsg: case cursor.BlinkMsg:
if m.state == pendingResponse { if m.state == pendingResponse {
// ensure we show the updated "wait for response" cursor blink state // ensure we show the updated "wait for response" cursor blink state
last := len(m.messages)-1
m.messageCache[last] = m.renderMessage(last)
m.updateContent() m.updateContent()
} }
case msgConversationPersisted: case msgConversationPersisted:
m.conversation = msg.conversation m.conversation = msg.conversation
m.messages = msg.messages m.messages = msg.messages
if msg.isNew { if msg.isNew {
m.rootMessages = []api.Message{m.messages[0]} m.rootMessages = []models.Message{m.messages[0]}
} }
m.rebuildMessageCache() m.rebuildMessageCache()
m.updateContent() m.updateContent()

View File

@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/api" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
@ -63,22 +63,22 @@ func (m Model) View() string {
return lipgloss.JoinVertical(lipgloss.Left, sections...) return lipgloss.JoinVertical(lipgloss.Left, sections...)
} }
func (m *Model) renderMessageHeading(i int, message *api.Message) string { func (m *Model) renderMessageHeading(i int, message *models.Message) string {
icon := "" icon := ""
friendly := message.Role.FriendlyRole() friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Faint(true).Bold(true) style := lipgloss.NewStyle().Faint(true).Bold(true)
switch message.Role { switch message.Role {
case api.MessageRoleSystem: case models.MessageRoleSystem:
icon = "⚙️" icon = "⚙️"
case api.MessageRoleUser: case models.MessageRoleUser:
style = userStyle style = userStyle
case api.MessageRoleAssistant: case models.MessageRoleAssistant:
style = assistantStyle style = assistantStyle
case api.MessageRoleToolCall: case models.MessageRoleToolCall:
style = assistantStyle style = assistantStyle
friendly = api.MessageRoleAssistant.FriendlyRole() friendly = models.MessageRoleAssistant.FriendlyRole()
case api.MessageRoleToolResult: case models.MessageRoleToolResult:
icon = "🔧" icon = "🔧"
} }
@ -124,9 +124,6 @@ func (m *Model) renderMessageHeading(i int, message *api.Message) string {
return messageHeadingStyle.Render(prefix + user + suffix) return messageHeadingStyle.Render(prefix + user + suffix)
} }
// renderMessages renders the message at the given index as it should be shown
// *at this moment* - we render differently depending on the current application
// state (window size, etc, etc).
func (m *Model) renderMessage(i int) string { func (m *Model) renderMessage(i int) string {
msg := &m.messages[i] msg := &m.messages[i]
@ -141,33 +138,33 @@ func (m *Model) renderMessage(i int) string {
} }
} }
isLast := i == len(m.messages)-1
isAssistant := msg.Role == api.MessageRoleAssistant
if m.state == pendingResponse && isLast && isAssistant {
// Show the assistant's cursor // Show the assistant's cursor
if m.state == pendingResponse && i == len(m.messages)-1 && msg.Role == models.MessageRoleAssistant {
sb.WriteString(m.replyCursor.View()) sb.WriteString(m.replyCursor.View())
} }
// Write tool call info // Write tool call info
var toolString string var toolString string
switch msg.Role { switch msg.Role {
case api.MessageRoleToolCall: case models.MessageRoleToolCall:
bytes, err := yaml.Marshal(msg.ToolCalls) bytes, err := yaml.Marshal(msg.ToolCalls)
if err != nil { if err != nil {
toolString = "Could not serialize ToolCalls" toolString = "Could not serialize ToolCalls"
} else { } else {
toolString = "tool_calls:\n" + string(bytes) toolString = "tool_calls:\n" + string(bytes)
} }
case api.MessageRoleToolResult: case models.MessageRoleToolResult:
if !m.showToolResults {
break
}
type renderedResult struct { type renderedResult struct {
ToolName string `yaml:"tool"` ToolName string `yaml:"tool"`
Result any `yaml:"result,omitempty"` Result any
} }
var toolResults []renderedResult var toolResults []renderedResult
for _, result := range msg.ToolResults { for _, result := range msg.ToolResults {
if m.showToolResults {
var jsonResult interface{} var jsonResult interface{}
err := json.Unmarshal([]byte(result.Result), &jsonResult) err := json.Unmarshal([]byte(result.Result), &jsonResult)
if err != nil { if err != nil {
@ -183,13 +180,6 @@ func (m *Model) renderMessage(i int) string {
Result: &jsonResult, Result: &jsonResult,
}) })
} }
} else {
// Only show the tool name when results are hidden
toolResults = append(toolResults, renderedResult{
ToolName: result.ToolName,
Result: "(hidden, press ctrl+t to view)",
})
}
} }
bytes, err := yaml.Marshal(toolResults) bytes, err := yaml.Marshal(toolResults)
@ -230,21 +220,40 @@ func (m *Model) conversationMessagesView() string {
for i, message := range m.messages { for i, message := range m.messages {
m.messageOffsets[i] = lineCnt m.messageOffsets[i] = lineCnt
switch message.Role {
case models.MessageRoleToolCall:
if !m.showToolResults && message.Content == "" {
continue
}
case models.MessageRoleToolResult:
if !m.showToolResults {
continue
}
}
heading := m.renderMessageHeading(i, &message) heading := m.renderMessageHeading(i, &message)
sb.WriteString(heading) sb.WriteString(heading)
sb.WriteString("\n") sb.WriteString("\n")
lineCnt += lipgloss.Height(heading) lineCnt += lipgloss.Height(heading)
rendered := m.messageCache[i] var rendered string
if m.state == pendingResponse && i == len(m.messages)-1 {
// do a direct render of final (assistant) message to handle the
// assistant cursor blink
rendered = m.renderMessage(i)
} else {
rendered = m.messageCache[i]
}
sb.WriteString(rendered) sb.WriteString(rendered)
sb.WriteString("\n") sb.WriteString("\n")
lineCnt += lipgloss.Height(rendered) lineCnt += lipgloss.Height(rendered)
} }
// Render a placeholder for the incoming assistant reply // Render a placeholder for the incoming assistant reply
if m.state == pendingResponse && m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant { if m.state == pendingResponse && (len(m.messages) == 0 || m.messages[len(m.messages)-1].Role != models.MessageRoleAssistant) {
heading := m.renderMessageHeading(-1, &api.Message{ heading := m.renderMessageHeading(-1, &models.Message{
Role: api.MessageRoleAssistant, Role: models.MessageRoleAssistant,
}) })
sb.WriteString(heading) sb.WriteString(heading)
sb.WriteString("\n") sb.WriteString("\n")

View File

@ -5,7 +5,7 @@ import (
"strings" "strings"
"time" "time"
"git.mlow.ca/mlow/lmcli/pkg/api" models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/tui/shared" "git.mlow.ca/mlow/lmcli/pkg/tui/shared"
"git.mlow.ca/mlow/lmcli/pkg/tui/styles" "git.mlow.ca/mlow/lmcli/pkg/tui/styles"
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util" tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
@ -16,15 +16,15 @@ import (
) )
type loadedConversation struct { type loadedConversation struct {
conv api.Conversation conv models.Conversation
lastReply api.Message lastReply models.Message
} }
type ( type (
// sent when conversation list is loaded // sent when conversation list is loaded
msgConversationsLoaded ([]loadedConversation) msgConversationsLoaded ([]loadedConversation)
// sent when a conversation is selected // sent when a conversation is selected
msgConversationSelected api.Conversation msgConversationSelected models.Conversation
) )
type Model struct { type Model struct {