Compare commits

..

No commits in common. "c51644e78e6821647141648ea37f49ec2923d94a" and "32eab7aa3564ac770faff01b78c6fcf2c5312c82" have entirely different histories.

14 changed files with 70 additions and 1285 deletions

10
go.mod
View File

@ -4,8 +4,6 @@ go 1.21
require ( require (
github.com/alecthomas/chroma/v2 v2.11.1 github.com/alecthomas/chroma/v2 v2.11.1
github.com/charmbracelet/bubbles v0.18.0
github.com/charmbracelet/bubbletea v0.25.0
github.com/charmbracelet/lipgloss v0.10.0 github.com/charmbracelet/lipgloss v0.10.0
github.com/go-yaml/yaml v2.1.0+incompatible github.com/go-yaml/yaml v2.1.0+incompatible
github.com/sashabaranov/go-openai v1.17.7 github.com/sashabaranov/go-openai v1.17.7
@ -16,9 +14,7 @@ require (
) )
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
@ -26,19 +22,13 @@ require (
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.14.0 // indirect golang.org/x/sys v0.14.0 // indirect
golang.org/x/term v0.6.0 // indirect
golang.org/x/text v0.3.8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect
) )

21
go.sum
View File

@ -4,18 +4,10 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s= github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE= github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
@ -38,17 +30,11 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
@ -69,16 +55,9 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -1,37 +0,0 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui"
"github.com/spf13/cobra"
)
func ChatCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "chat [conversation]",
Short: "Open the chat interface",
Long: `Open the chat interface, optionally on a given conversation.`,
RunE: func(cmd *cobra.Command, args []string) error {
// TODO: implement jump-to-conversation logic
shortname := ""
if len(args) == 1 {
shortname = args[0]
}
err := tui.Launch(ctx, shortname)
if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
return cmd
}

View File

@ -23,7 +23,6 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
chatCmd := ChatCmd(ctx)
continueCmd := ContinueCmd(ctx) continueCmd := ContinueCmd(ctx)
cloneCmd := CloneCmd(ctx) cloneCmd := CloneCmd(ctx)
editCmd := EditCmd(ctx) editCmd := EditCmd(ctx)
@ -49,7 +48,6 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
} }
root.AddCommand( root.AddCommand(
chatCmd,
cloneCmd, cloneCmd,
continueCmd, continueCmd,
editCmd, editCmd,

View File

@ -118,18 +118,11 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
func FormatForExternalPrompt(messages []model.Message, system bool) string { func FormatForExternalPrompt(messages []model.Message, system bool) string {
sb := strings.Builder{} sb := strings.Builder{}
for _, message := range messages { for _, message := range messages {
if message.Content == "" { if message.Role != model.MessageRoleUser && (message.Role != model.MessageRoleSystem || !system) {
continue continue
} }
switch message.Role { sb.WriteString(fmt.Sprintf("<%s>\n", message.Role.FriendlyRole()))
case model.MessageRoleAssistant, model.MessageRoleToolCall: sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.Content))
sb.WriteString("Assistant:\n\n")
case model.MessageRoleUser:
sb.WriteString("User:\n\n")
default:
continue
}
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
} }
return sb.String() return sb.String()
} }
@ -140,32 +133,13 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
return "", err return "", err
} }
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective. const header = "Generate a concise 4-5 word title for the conversation below."
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, FormatForExternalPrompt(messages, false))
Example conversation:
"""
User:
Hello!
Assistant:
Hello! How may I assist you?
"""
Example response:
"""
Title: A brief introduction
"""
`
conversation := FormatForExternalPrompt(messages, false)
generateRequest := []model.Message{ generateRequest := []model.Message{
{ {
Role: model.MessageRoleUser, Role: model.MessageRoleUser,
Content: fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n%s", conversation, prompt), Content: prompt,
}, },
} }
@ -184,15 +158,12 @@ Title: A brief introduction
return "", err return "", err
} }
response = strings.TrimPrefix(response, "Title: ")
response = strings.Trim(response, "\"")
return response, nil return response, nil
} }
// ShowWaitAnimation prints an animated ellipses to stdout until something is // ShowWaitAnimation prints an animated ellipses to stdout until something is
// received on the signal channel. An empty string sent to the channel to // received on the signal channel. An empty string sent to the channel to
// notify the caller that the animation has completed (carriage returned). // noftify the caller that the animation has completed (carriage returned).
func ShowWaitAnimation(signal chan any) { func ShowWaitAnimation(signal chan any) {
// Save the current cursor position // Save the current cursor position
fmt.Print("\033[s") fmt.Print("\033[s")

View File

@ -41,28 +41,18 @@ type RequestParameters struct {
ToolBag []Tool ToolBag []Tool
} }
func (m *MessageRole) IsAssistant() bool {
switch *m {
case MessageRoleAssistant, MessageRoleToolCall:
return true
}
return false
}
// FriendlyRole returns a human friendly signifier for the message's role. // FriendlyRole returns a human friendly signifier for the message's role.
func (m *MessageRole) FriendlyRole() string { func (m *MessageRole) FriendlyRole() string {
var friendlyRole string
switch *m { switch *m {
case MessageRoleUser: case MessageRoleUser:
return "You" friendlyRole = "You"
case MessageRoleSystem: case MessageRoleSystem:
return "System" friendlyRole = "System"
case MessageRoleAssistant: case MessageRoleAssistant:
return "Assistant" friendlyRole = "Assistant"
case MessageRoleToolCall:
return "Tool Call"
case MessageRoleToolResult:
return "Tool Result"
default: default:
return string(*m) friendlyRole = string(*m)
} }
return friendlyRole
} }

View File

@ -46,8 +46,6 @@ type Response struct {
Type string `json:"type"` Type string `json:"type"`
Role string `json:"role"` Role string `json:"role"`
Content []OriginalContent `json:"content"` Content []OriginalContent `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence string `json:"stop_sequence"`
} }
const FUNCTION_STOP_SEQUENCE = "</function_calls>" const FUNCTION_STOP_SEQUENCE = "</function_calls>"
@ -70,7 +68,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
startIdx := 0 startIdx := 0
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem { if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
requestBody.System = messages[0].Content requestBody.System = messages[0].Content
requestBody.Messages = requestBody.Messages[1:] requestBody.Messages = requestBody.Messages[:len(messages)-1]
startIdx = 1 startIdx = 1
} }
@ -96,11 +94,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
if err != nil { if err != nil {
panic("Could not serialize []ToolCall to XMLFunctionCall") panic("Could not serialize []ToolCall to XMLFunctionCall")
} }
if len(message.Content) > 0 { message.Content += xmlString
message.Content += fmt.Sprintf("\n\n%s", xmlString)
} else {
message.Content = xmlString
}
case model.MessageRoleToolResult: case model.MessageRoleToolResult:
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults) xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
xmlString, err := xmlFuncResults.XMLString() xmlString, err := xmlFuncResults.XMLString()
@ -149,10 +143,6 @@ func (c *AnthropicClient) CreateChatCompletion(
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages) request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(ctx, c, request)
@ -168,14 +158,6 @@ func (c *AnthropicClient) CreateChatCompletion(
} }
sb := strings.Builder{} sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
}
for _, content := range response.Content { for _, content := range response.Content {
var reply model.Message var reply model.Message
switch content.Type { switch content.Type {
@ -203,10 +185,6 @@ func (c *AnthropicClient) CreateChatCompletionStream(
callback provider.ReplyCallback, callback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
request := buildRequest(params, messages) request := buildRequest(params, messages)
request.Stream = true request.Stream = true
@ -216,18 +194,9 @@ func (c *AnthropicClient) CreateChatCompletionStream(
} }
defer resp.Body.Close() defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
sb := strings.Builder{} sb := strings.Builder{}
lastMessage := messages[len(messages)-1]
continuation := false
if messages[len(messages)-1].Role.IsAssistant() {
// this is a continuation of a previous assistant reply, so we'll
// include its contents in the final result
sb.WriteString(lastMessage.Content)
continuation = true
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@ -302,21 +271,24 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found") return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
} }
funcCallXml := content[start:]
funcCallXml += FUNCTION_STOP_SEQUENCE
sb.WriteString(FUNCTION_STOP_SEQUENCE) sb.WriteString(FUNCTION_STOP_SEQUENCE)
output <- FUNCTION_STOP_SEQUENCE output <- FUNCTION_STOP_SEQUENCE
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE // Extract function calls
var functionCalls XMLFunctionCalls var functionCalls XMLFunctionCalls
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls) err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err) return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
} }
// Execute function calls
toolCall := model.Message{ toolCall := model.Message{
Role: model.MessageRoleToolCall, Role: model.MessageRoleToolCall,
// function call xml stripped from content for model interop // xml stripped from content
Content: strings.TrimSpace(content[:start]), Content: content[:start],
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls), ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
} }
@ -325,36 +297,31 @@ func (c *AnthropicClient) CreateChatCompletionStream(
return "", err return "", err
} }
toolResult := model.Message{ toolReply := model.Message{
Role: model.MessageRoleToolResult, Role: model.MessageRoleToolResult,
ToolResults: toolResults, ToolResults: toolResults,
} }
if callback != nil { if callback != nil {
callback(toolCall) callback(toolCall)
callback(toolResult) callback(toolReply)
} }
if continuation { // Recurse into CreateChatCompletionStream with the tool call replies
messages[len(messages)-1] = toolCall // added to the original messages
} else { messages = append(append(messages, toolCall), toolReply)
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output) return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
content := sb.String()
if callback != nil { if callback != nil {
callback(model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content, Content: sb.String(),
}) })
} }
return content, nil return sb.String(), nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
default: default:

View File

@ -9,10 +9,9 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
const TOOL_PREAMBLE = `You have access to the following tools when replying. const TOOL_PREAMBLE = `In this environment you have access to a set of tools which may assist you in fulfilling user requests.
You may call them like this: You may call them like this:
<function_calls> <function_calls>
<invoke> <invoke>
<tool_name>$TOOL_NAME</tool_name> <tool_name>$TOOL_NAME</tool_name>
@ -25,14 +24,6 @@ You may call them like this:
Here are the tools available:` Here are the tools available:`
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
type XMLTools struct { type XMLTools struct {
XMLName struct{} `xml:"tools"` XMLName struct{} `xml:"tools"`
ToolDescriptions []XMLToolDescription `xml:"tool_description"` ToolDescriptions []XMLToolDescription `xml:"tool_description"`
@ -160,7 +151,7 @@ func buildToolsSystemPrompt(tools []model.Tool) string {
if err != nil { if err != nil {
panic("Could not serialize []model.Tool to XMLTools") panic("Could not serialize []model.Tool to XMLTools")
} }
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER return TOOL_PREAMBLE + "\n" + xmlToolsString + "\n"
} }
func (x XMLTools) XMLString() (string, error) { func (x XMLTools) XMLString() (string, error) {

View File

@ -137,15 +137,7 @@ func handleToolCalls(
params model.RequestParameters, params model.RequestParameters,
content string, content string,
toolCalls []openai.ToolCall, toolCalls []openai.ToolCall,
callback provider.ReplyCallback,
messages []model.Message,
) ([]model.Message, error) { ) ([]model.Message, error) {
lastMessage := messages[len(messages)-1]
continuation := false
if lastMessage.Role.IsAssistant() {
continuation = true
}
toolCall := model.Message{ toolCall := model.Message{
Role: model.MessageRoleToolCall, Role: model.MessageRoleToolCall,
Content: content, Content: content,
@ -162,19 +154,7 @@ func handleToolCalls(
ToolResults: toolResults, ToolResults: toolResults,
} }
if callback != nil { return []model.Message{toolCall, toolResult}, nil
callback(toolCall)
callback(toolResult)
}
if continuation {
messages[len(messages)-1] = toolCall
} else {
messages = append(messages, toolCall)
}
messages = append(messages, toolResult)
return messages, nil
} }
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
@ -183,10 +163,6 @@ func (c *OpenAIClient) CreateChatCompletion(
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(ctx, req) resp, err := client.CreateChatCompletion(ctx, req)
@ -196,46 +172,41 @@ func (c *OpenAIClient) CreateChatCompletion(
choice := resp.Choices[0] choice := resp.Choices[0]
var content string
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content = lastMessage.Content + choice.Message.Content
} else {
content = choice.Message.Content
}
toolCalls := choice.Message.ToolCalls toolCalls := choice.Message.ToolCalls
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content, toolCalls, callback, messages) results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
if err != nil { if err != nil {
return content, err return "", err
}
if callback != nil {
for _, result := range results {
callback(result)
}
} }
// Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, callback) return c.CreateChatCompletion(ctx, params, messages, callback)
} }
if callback != nil { if callback != nil {
callback(model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content, Content: choice.Message.Content,
}) })
} }
// Return the user-facing message. // Return the user-facing message.
return content, nil return choice.Message.Content, nil
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, callbback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
if len(messages) == 0 {
return "", fmt.Errorf("Can't create completion from no messages")
}
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
@ -248,11 +219,6 @@ func (c *OpenAIClient) CreateChatCompletionStream(
content := strings.Builder{} content := strings.Builder{}
toolCalls := []openai.ToolCall{} toolCalls := []openai.ToolCall{}
lastMessage := messages[len(messages)-1]
if lastMessage.Role.IsAssistant() {
content.WriteString(lastMessage.Content)
}
// Iterate stream segments // Iterate stream segments
for { for {
response, e := stream.Recv() response, e := stream.Recv()
@ -285,21 +251,28 @@ func (c *OpenAIClient) CreateChatCompletionStream(
} }
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
messages, err := handleToolCalls(params, content.String(), toolCalls, callback, messages) results, err := handleToolCalls(params, content.String(), toolCalls)
if err != nil { if err != nil {
return content.String(), err return content.String(), err
} }
if callbback != nil {
for _, result := range results {
callbback(result)
}
}
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
return c.CreateChatCompletionStream(ctx, params, messages, callback, output) messages = append(messages, results...)
} else { return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
if callback != nil { }
callback(model.Message{
if callbback != nil {
callbback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) })
} }
}
return content.String(), err return content.String(), err
} }

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
sqids "github.com/sqids/sqids-go" sqids "github.com/sqids/sqids-go"
@ -27,7 +26,6 @@ type ConversationStore interface {
SaveMessage(message *model.Message) error SaveMessage(message *model.Message) error
DeleteMessage(message *model.Message) error DeleteMessage(message *model.Message) error
UpdateMessage(message *model.Message) error UpdateMessage(message *model.Message) error
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
} }
type SQLStore struct { type SQLStore struct {
@ -121,12 +119,3 @@ func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
return &message, err return &message, err
} }
// AddReply adds the given messages as a reply to the given conversation, can be
// used to easily copy a message associated with one conversation, to another
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
m.ConversationID = c.ID
m.ID = 0
m.CreatedAt = time.Time{}
return &m, s.SaveMessage(&m)
}

View File

@ -1,143 +0,0 @@
package tools
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
)
const TREE_DESCRIPTION = `Retrieve a tree view of a directory's contents.
Example result:
{
"message": "success",
"result": ".
a_directory/
file1.txt (100 bytes)
file2.txt (200 bytes)
a_file.txt (123 bytes)
another_file.txt (456 bytes)"
}
`
var DirTreeTool = model.Tool{
Name: "dir_tree",
Description: TREE_DESCRIPTION,
Parameters: []model.ToolParameter{
{
Name: "relative_path",
Type: "string",
Description: "If set, display the tree starting from this path relative to the current one.",
},
{
Name: "max_depth",
Type: "integer",
Description: "Maximum depth of recursion. Default is unlimited.",
},
},
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
var relativeDir string
tmp, ok := args["relative_dir"]
if ok {
relativeDir, ok = tmp.(string)
if !ok {
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
}
}
var maxDepth int = -1
tmp, ok = args["max_depth"]
if ok {
maxDepth, ok = tmp.(int)
if !ok {
if tmps, ok := tmp.(string); ok {
tmpi, err := strconv.Atoi(tmps)
maxDepth = tmpi
if err != nil {
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
}
} else {
return "", fmt.Errorf("Invalid max_depth in function arguments: %v", tmp)
}
}
}
result := tree(relativeDir, maxDepth)
ret, err := result.ToJson()
if err != nil {
return "", fmt.Errorf("Could not serialize result: %v", err)
}
return ret, nil
},
}
func tree(path string, maxDepth int) model.CallResult {
if path == "" {
path = "."
}
ok, reason := toolutil.IsPathWithinCWD(path)
if !ok {
return model.CallResult{Message: reason}
}
var treeOutput strings.Builder
treeOutput.WriteString(path + "\n")
err := buildTree(&treeOutput, path, "", maxDepth)
if err != nil {
return model.CallResult{
Message: err.Error(),
}
}
return model.CallResult{Result: treeOutput.String()}
}
func buildTree(output *strings.Builder, path string, prefix string, maxDepth int) error {
files, err := os.ReadDir(path)
if err != nil {
return err
}
for i, file := range files {
if strings.HasPrefix(file.Name(), ".") {
// Skip hidden files and directories
continue
}
isLast := i == len(files)-1
var branch string
if isLast {
branch = "└── "
} else {
branch = "├── "
}
info, _ := file.Info()
size := info.Size()
sizeStr := fmt.Sprintf(" (%d bytes)", size)
output.WriteString(prefix + branch + file.Name())
if file.IsDir() {
output.WriteString("/\n")
if maxDepth != 0 {
var nextPrefix string
if isLast {
nextPrefix = prefix + " "
} else {
nextPrefix = prefix + "│ "
}
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, maxDepth-1)
}
} else {
output.WriteString(sizeStr + "\n")
}
}
return nil
}

View File

@ -2,12 +2,12 @@ package tools
import ( import (
"fmt" "fmt"
"os"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
var AvailableTools map[string]model.Tool = map[string]model.Tool{ var AvailableTools map[string]model.Tool = map[string]model.Tool{
"dir_tree": DirTreeTool,
"read_dir": ReadDirTool, "read_dir": ReadDirTool,
"read_file": ReadFileTool, "read_file": ReadFileTool,
"write_file": WriteFileTool, "write_file": WriteFileTool,
@ -29,6 +29,9 @@ func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name) return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
} }
// TODO: ability to silence this
fmt.Fprintf(os.Stderr, "\nINFO: Executing tool '%s' with args %s\n", toolCall.Name, toolCall.Parameters)
// Execute the tool // Execute the tool
result, err := tool.Impl(tool, toolCall.Parameters) result, err := tool.Impl(tool, toolCall.Parameters)
if err != nil { if err != nil {

View File

@ -1,844 +0,0 @@
package tui
// The terminal UI for lmcli, launched from the `lmcli chat` command
// TODO:
// - conversation list view
// - change model
// - rename conversation
// - set system prompt
// - system prompt library?
import (
"context"
"fmt"
"strings"
"time"
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"github.com/charmbracelet/bubbles/spinner"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/muesli/reflow/wordwrap"
)
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type editorTarget int
const (
input editorTarget = iota
selectedMessage
)
type model struct {
width int
height int
ctx *lmcli.Context
convShortname string
// application state
conversation *models.Conversation
messages []models.Message
waitingForReply bool
editorTarget editorTarget
stopSignal chan interface{}
replyChan chan models.Message
replyChunkChan chan string
persistence bool // whether we will save new messages in the conversation
err error
// ui state
focus focusState
wrap bool // whether message content is wrapped to viewport width
status string // a general status message
highlightCache []string // a cache of syntax highlighted message content
messageOffsets []int
selectedMessage int
// ui elements
content viewport.Model
input textarea.Model
spinner spinner.Model
}
type message struct {
role string
content string
}
// custom tea.Msg types
type (
// sent on each chunk received from LLM
msgResponseChunk string
// sent when response is finished being received
msgResponseEnd string
// a special case of msgError that stops the response waiting animation
msgResponseError error
// sent on each completed reply
msgAssistantReply models.Message
// sent when a conversation is (re)loaded
msgConversationLoaded *models.Conversation
// sent when a new conversation title is set
msgConversationTitleChanged string
// send when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when an error occurs
msgError error
)
// styles
var (
userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12"))
messageStyle = lipgloss.NewStyle().PaddingLeft(2).PaddingRight(2)
headerStyle = lipgloss.NewStyle().
Background(lipgloss.Color("0"))
conversationStyle = lipgloss.NewStyle().
MarginTop(1).
MarginBottom(1)
footerStyle = lipgloss.NewStyle().
BorderTop(true).
BorderStyle(lipgloss.NormalBorder())
)
func (m model) Init() tea.Cmd {
return tea.Batch(
textarea.Blink,
m.spinner.Tick,
m.loadConversation(m.convShortname),
m.waitForChunk(),
m.waitForReply(),
)
}
func wrapError(err error) tea.Cmd {
return func() tea.Msg {
return msgError(err)
}
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case msgTempfileEditorClosed:
contents := string(msg)
switch m.editorTarget {
case input:
m.input.SetValue(contents)
case selectedMessage:
m.setMessageContents(m.selectedMessage, contents)
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
// update persisted message
err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
if err != nil {
cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err)))
}
}
m.updateContent()
}
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c":
if m.waitingForReply {
m.stopSignal <- ""
return m, nil
} else {
return m, tea.Quit
}
case "ctrl+p":
m.persistence = !m.persistence
case "ctrl+w":
m.wrap = !m.wrap
m.updateContent()
case "q":
if m.focus != focusInput {
return m, tea.Quit
}
default:
var inputHandled tea.Cmd
switch m.focus {
case focusInput:
inputHandled = m.handleInputKey(msg)
case focusMessages:
inputHandled = m.handleMessagesKey(msg)
}
if inputHandled != nil {
return m, inputHandled
}
}
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.content.Width = msg.Width
m.content.Height = msg.Height - m.getFixedComponentHeight()
m.input.SetWidth(msg.Width - 1)
m.updateContent()
case msgConversationLoaded:
m.conversation = (*models.Conversation)(msg)
cmds = append(cmds, m.loadMessages(m.conversation))
case msgMessagesLoaded:
m.setMessages(msg)
m.updateContent()
case msgResponseChunk:
chunk := string(msg)
last := len(m.messages) - 1
if last >= 0 && m.messages[last].Role.IsAssistant() {
m.setMessageContents(last, m.messages[last].Content+chunk)
} else {
m.addMessage(models.Message{
Role: models.MessageRoleAssistant,
Content: chunk,
})
}
m.updateContent()
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
case msgAssistantReply:
// the last reply that was being worked on is finished
reply := models.Message(msg)
reply.Content = strings.TrimSpace(reply.Content)
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgAssistantReply")
}
if reply.Role.IsAssistant() && m.messages[last].Role.IsAssistant() {
// this was a continuation, so replace the previous message with the completed reply
m.setMessage(last, reply)
} else {
m.addMessage(reply)
}
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
cmds = append(cmds, wrapError(err))
} else {
cmds = append(cmds, m.persistConversation())
}
}
if m.conversation.Title == "" {
cmds = append(cmds, m.generateConversationTitle())
}
m.updateContent()
cmds = append(cmds, m.waitForReply())
case msgResponseEnd:
m.waitingForReply = false
last := len(m.messages) - 1
if last < 0 {
panic("Unexpected empty messages handling msgResponseEnd")
}
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
m.updateContent()
m.status = "Press ctrl+s to send"
case msgResponseError:
m.waitingForReply = false
m.status = "Press ctrl+s to send"
m.err = error(msg)
case msgConversationTitleChanged:
title := string(msg)
m.conversation.Title = title
if m.persistence {
err := m.ctx.Store.SaveConversation(m.conversation)
if err != nil {
cmds = append(cmds, wrapError(err))
}
}
case msgError:
m.err = error(msg)
}
var cmd tea.Cmd
m.spinner, cmd = m.spinner.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
inputCaptured := false
m.input, cmd = m.input.Update(msg)
if cmd != nil {
inputCaptured = true
cmds = append(cmds, cmd)
}
if !inputCaptured {
m.content, cmd = m.content.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
}
return m, tea.Batch(cmds...)
}
func (m model) View() string {
if m.width == 0 {
// this is the case upon initial startup, but it's also a safe bet that
// we can just skip rendering if the terminal is really 0 width...
// without this, the m.*View() functions may crash
return ""
}
sections := make([]string, 0, 6)
sections = append(sections, m.headerView())
sections = append(sections, m.contentView())
error := m.errorView()
if error != "" {
sections = append(sections, error)
}
sections = append(sections, m.inputView())
sections = append(sections, m.footerView())
return lipgloss.JoinVertical(
lipgloss.Left,
sections...,
)
}
// returns the total height of "fixed" components, which are those which don't
// change height dependent on window size.
func (m *model) getFixedComponentHeight() int {
h := 0
h += m.input.Height()
h += lipgloss.Height(m.headerView())
h += lipgloss.Height(m.footerView())
errorView := m.errorView()
if errorView != "" {
h += lipgloss.Height(errorView)
}
return h
}
func (m *model) headerView() string {
titleStyle := lipgloss.NewStyle().
PaddingLeft(1).
PaddingRight(1).
Bold(true)
var title string
if m.conversation != nil && m.conversation.Title != "" {
title = m.conversation.Title
} else {
title = "Untitled"
}
part := titleStyle.Render(title)
return headerStyle.Width(m.width).Render(part)
}
func (m *model) contentView() string {
return m.content.View()
}
func (m *model) errorView() string {
if m.err == nil {
return ""
}
return lipgloss.NewStyle().
Width(m.width).
AlignHorizontal(lipgloss.Center).
Bold(true).
Foreground(lipgloss.Color("1")).
Render(fmt.Sprintf("%s", m.err))
}
func (m *model) inputView() string {
return m.input.View()
}
func (m *model) footerView() string {
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
segmentSeparator := "|"
savingStyle := segmentStyle.Copy().Bold(true)
saving := ""
if m.persistence {
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
} else {
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
}
status := m.status
if m.waitingForReply {
status += m.spinner.View()
}
leftSegments := []string{
saving,
segmentStyle.Render(status),
}
rightSegments := []string{
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
}
left := strings.Join(leftSegments, segmentSeparator)
right := strings.Join(rightSegments, segmentSeparator)
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
remaining := m.width - totalWidth
var padding string
if remaining > 0 {
padding = strings.Repeat(" ", remaining)
}
footer := left + padding + right
if remaining < 0 {
ellipses := "... "
// this doesn't work very well, due to trying to trim a string with
// ansii chars already in it
footer = footer[:(len(footer)+remaining)-len(ellipses)-3] + ellipses
}
return footerStyle.Width(m.width).Render(footer)
}
func initialModel(ctx *lmcli.Context, convShortname string) model {
m := model{
ctx: ctx,
convShortname: convShortname,
conversation: &models.Conversation{},
persistence: true,
stopSignal: make(chan interface{}),
replyChan: make(chan models.Message),
replyChunkChan: make(chan string),
wrap: true,
selectedMessage: -1,
}
m.content = viewport.New(0, 0)
m.input = textarea.New()
m.input.CharLimit = 0
m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.ShowLineNumbers = false
m.input.SetHeight(4)
m.input.Focus()
m.spinner = spinner.New(spinner.WithSpinner(
spinner.Spinner{
Frames: []string{
". ",
".. ",
"...",
".. ",
". ",
" ",
},
FPS: time.Second / 3,
},
))
m.waitingForReply = false
m.status = "Press ctrl+s to send"
return m
}
// fraction is the fraction of the total screen height into view the offset
// should be scrolled into view. 0.5 = items will be snapped to middle of
// view
func scrollIntoView(vp *viewport.Model, offset int, fraction float32) {
currentOffset := vp.YOffset
if offset >= currentOffset && offset < currentOffset+vp.Height {
return
}
distance := currentOffset - offset
if distance < 0 {
// we should scroll down until it just comes into view
vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1)
} else {
// we should scroll up
vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction))
}
}
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "tab":
m.focus = focusInput
m.updateContent()
m.input.Focus()
case "e":
message := m.messages[m.selectedMessage]
cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
m.editorTarget = selectedMessage
return cmd
case "ctrl+k":
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage--
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+j":
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
m.selectedMessage++
m.updateContent()
offset := m.messageOffsets[m.selectedMessage]
scrollIntoView(&m.content, offset, 0.1)
}
case "ctrl+r":
// resubmit the conversation with all messages up until and including the selected message
if m.waitingForReply || len(m.messages) == 0 {
return nil
}
m.messages = m.messages[:m.selectedMessage+1]
m.highlightCache = m.highlightCache[:m.selectedMessage+1]
m.updateContent()
m.content.GotoBottom()
return m.promptLLM()
}
return nil
}
func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.focus = focusMessages
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
m.selectedMessage = len(m.messages) - 1
}
m.updateContent()
m.input.Blur()
case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value())
if strings.TrimSpace(userInput) == "" {
return nil
}
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
return wrapError(fmt.Errorf("Can't reply to a user message"))
}
reply := models.Message{
Role: models.MessageRoleUser,
Content: userInput,
}
if m.persistence {
var err error
if m.conversation.ID == 0 {
err = m.ctx.Store.SaveConversation(m.conversation)
}
if err != nil {
return wrapError(err)
}
// ensure all messages up to the one we're about to add are persisted
cmd := m.persistConversation()
if cmd != nil {
return cmd
}
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
if err != nil {
return wrapError(err)
}
reply = *savedReply
}
m.input.SetValue("")
m.addMessage(reply)
m.updateContent()
m.content.GotoBottom()
return m.promptLLM()
case "ctrl+e":
cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
m.editorTarget = input
return cmd
}
return nil
}
func (m *model) loadConversation(shortname string) tea.Cmd {
return func() tea.Msg {
if shortname == "" {
return nil
}
c, err := m.ctx.Store.ConversationByShortName(shortname)
if err != nil {
return msgError(fmt.Errorf("Could not lookup conversation: %v", err))
}
if c.ID == 0 {
return msgError(fmt.Errorf("Conversation not found: %s", shortname))
}
return msgConversationLoaded(c)
}
}
func (m *model) loadMessages(c *models.Conversation) tea.Cmd {
return func() tea.Msg {
messages, err := m.ctx.Store.Messages(c)
if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
return msgMessagesLoaded(messages)
}
}
func (m *model) waitForReply() tea.Cmd {
return func() tea.Msg {
return msgAssistantReply(<-m.replyChan)
}
}
func (m *model) waitForChunk() tea.Cmd {
return func() tea.Msg {
return msgResponseChunk(<-m.replyChunkChan)
}
}
func (m *model) generateConversationTitle() tea.Cmd {
return func() tea.Msg {
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
if err != nil {
return msgError(err)
}
return msgConversationTitleChanged(title)
}
}
func (m *model) promptLLM() tea.Cmd {
m.waitingForReply = true
m.status = "Press ctrl+c to cancel"
return func() tea.Msg {
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
if err != nil {
return msgError(err)
}
requestParams := models.RequestParameters{
Model: *m.ctx.Config.Defaults.Model,
MaxTokens: *m.ctx.Config.Defaults.MaxTokens,
Temperature: *m.ctx.Config.Defaults.Temperature,
ToolBag: m.ctx.EnabledTools,
}
replyHandler := func(msg models.Message) {
m.replyChan <- msg
}
ctx, cancel := context.WithCancel(context.Background())
canceled := false
go func() {
select {
case <-m.stopSignal:
canceled = true
cancel()
}
}()
resp, err := completionProvider.CreateChatCompletionStream(
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
)
if err != nil && !canceled {
return msgResponseError(err)
}
return msgResponseEnd(resp)
}
}
func (m *model) persistConversation() tea.Cmd {
existingMessages, err := m.ctx.Store.Messages(m.conversation)
if err != nil {
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
}
existingById := make(map[uint]*models.Message, len(existingMessages))
for _, msg := range existingMessages {
existingById[msg.ID] = &msg
}
currentById := make(map[uint]*models.Message, len(m.messages))
for _, msg := range m.messages {
currentById[msg.ID] = &msg
}
for _, msg := range existingMessages {
_, ok := currentById[msg.ID]
if !ok {
err := m.ctx.Store.DeleteMessage(&msg)
if err != nil {
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
}
}
}
for i, msg := range m.messages {
if msg.ID > 0 {
exist, ok := existingById[msg.ID]
if ok {
if msg.Content == exist.Content {
continue
}
// update message when contents don't match that of store
err := m.ctx.Store.UpdateMessage(&msg)
if err != nil {
return wrapError(err)
}
} else {
// this would be quite odd... and I'm not sure how to handle
// it at the time of writing this
}
} else {
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
if err != nil {
return wrapError(err)
}
m.setMessage(i, *newMessage)
}
}
return nil
}
func (m *model) setMessages(messages []models.Message) {
m.messages = messages
m.highlightCache = make([]string, len(messages))
for i, msg := range m.messages {
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.highlightCache[i] = highlighted
}
}
func (m *model) setMessage(i int, msg models.Message) {
if i >= len(m.messages) {
panic("i out of range")
}
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.messages[i] = msg
m.highlightCache[i] = highlighted
}
func (m *model) addMessage(msg models.Message) {
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
m.messages = append(m.messages, msg)
m.highlightCache = append(m.highlightCache, highlighted)
}
func (m *model) setMessageContents(i int, content string) {
if i >= len(m.messages) {
panic("i out of range")
}
highlighted, _ := m.ctx.Chroma.HighlightS(content)
m.messages[i].Content = content
m.highlightCache[i] = highlighted
}
func (m *model) updateContent() {
atBottom := m.content.AtBottom()
m.content.SetContent(m.conversationView())
if atBottom {
// if we were at bottom before the update, scroll with the output
m.content.GotoBottom()
}
}
// render the conversation into a string
func (m *model) conversationView() string {
sb := strings.Builder{}
msgCnt := len(m.messages)
m.messageOffsets = make([]int, len(m.messages))
lineCnt := conversationStyle.GetMarginTop()
for i, message := range m.messages {
m.messageOffsets[i] = lineCnt
icon := "⚙️"
friendly := message.Role.FriendlyRole()
style := lipgloss.NewStyle().Bold(true).Faint(true)
switch message.Role {
case models.MessageRoleUser:
icon = ""
style = userStyle
case models.MessageRoleAssistant:
icon = ""
style = assistantStyle
case models.MessageRoleToolCall, models.MessageRoleToolResult:
icon = "🔧"
}
// write message heading with space for content
user := style.Render(icon + friendly)
var prefix string
var suffix string
faint := lipgloss.NewStyle().Faint(true)
if m.focus == focusMessages {
if i == m.selectedMessage {
prefix = "> "
}
suffix += faint.Render(fmt.Sprintf(" (%d/%d)", i+1, msgCnt))
}
if message.ID == 0 {
suffix += faint.Render(" (not saved)")
}
header := lipgloss.NewStyle().PaddingLeft(1).Render(prefix + user + suffix)
sb.WriteString(header)
lineCnt += lipgloss.Height(header)
// TODO: special rendering for tool calls/results?
if message.Content != "" {
sb.WriteString("\n\n")
lineCnt += 1
// write message contents
var highlighted string
if m.highlightCache[i] == "" {
highlighted = message.Content
} else {
highlighted = m.highlightCache[i]
}
var contents string
if m.wrap {
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 2
wrapped := wordwrap.String(highlighted, wrapWidth)
contents = wrapped
} else {
contents = highlighted
}
sb.WriteString(messageStyle.Width(0).Render(contents))
lineCnt += lipgloss.Height(contents)
}
if i < msgCnt-1 {
sb.WriteString("\n\n")
lineCnt += 1
}
}
return conversationStyle.Render(sb.String())
}
func Launch(ctx *lmcli.Context, convShortname string) error {
p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen())
if _, err := p.Run(); err != nil {
return fmt.Errorf("Error running program: %v", err)
}
return nil
}

View File

@ -1,42 +0,0 @@
package tui
import (
"os"
"os/exec"
"strings"
tea "github.com/charmbracelet/bubbletea"
)
type msgTempfileEditorClosed string
// openTempfileEditor opens an $EDITOR on a new temporary file with the given
// content. Upon closing, the contents of the file are read back returned
// wrapped in a msgTempfileEditorClosed returned by the tea.Cmd
func openTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
msgFile, _ := os.CreateTemp("/tmp", pattern)
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
if err != nil {
return wrapError(err)
}
editor := os.Getenv("EDITOR")
if editor == "" {
editor = "vim"
}
c := exec.Command(editor, msgFile.Name())
return tea.ExecProcess(c, func(err error) tea.Msg {
bytes, err := os.ReadFile(msgFile.Name())
if err != nil {
return msgError(err)
}
fileContents := string(bytes)
if strings.HasPrefix(fileContents, placeholder) {
fileContents = fileContents[len(placeholder):]
}
stripped := strings.Trim(fileContents, "\n \t")
return msgTempfileEditorClosed(stripped)
})
}