Compare commits

...

6 Commits

Author SHA1 Message Date
3f765234de tui: ability to cancel request in flight 2024-03-12 20:41:34 +00:00
21411c2732 tui: add focus switching between input/messages view 2024-03-12 20:40:49 +00:00
99794addee tui: removed confirm before send, dynamic footer
footer now rendered based on model data, instead of being set to a fixed
string
2024-03-12 20:40:49 +00:00
a47c1a76b4 tui: use ctx chroma highlighter 2024-03-12 20:40:49 +00:00
96fdae982e Add initial TUI 2024-03-12 20:40:44 +00:00
91d3c9c2e1 Update ChatCompletionClient
Instead of CreateChatCompletion* accepting a pointer to a slice of reply
messages, it accepts a callback which is called with each successive
reply the conversation.

This gives the caller more flexibility in how it handles replies (e.g.
it can react to them immediately now, instead of waiting for the entire
call to finish)
2024-03-12 20:39:34 +00:00
11 changed files with 480 additions and 45 deletions

10
go.mod
View File

@ -4,6 +4,8 @@ go 1.21
require ( require (
github.com/alecthomas/chroma/v2 v2.11.1 github.com/alecthomas/chroma/v2 v2.11.1
github.com/charmbracelet/bubbles v0.18.0
github.com/charmbracelet/bubbletea v0.25.0
github.com/charmbracelet/lipgloss v0.10.0 github.com/charmbracelet/lipgloss v0.10.0
github.com/go-yaml/yaml v2.1.0+incompatible github.com/go-yaml/yaml v2.1.0+incompatible
github.com/sashabaranov/go-openai v1.17.7 github.com/sashabaranov/go-openai v1.17.7
@ -14,7 +16,9 @@ require (
) )
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
@ -22,13 +26,19 @@ require (
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-isatty v0.0.18 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.14.0 // indirect golang.org/x/sys v0.14.0 // indirect
golang.org/x/term v0.6.0 // indirect
golang.org/x/text v0.3.8 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect
) )

21
go.sum
View File

@ -4,10 +4,18 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw= github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk= github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s= github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE= github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
@ -30,11 +38,17 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98= github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
@ -55,9 +69,16 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

37
pkg/cmd/chat.go Normal file
View File

@ -0,0 +1,37 @@
package cmd
import (
"fmt"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/tui"
"github.com/spf13/cobra"
)
func ChatCmd(ctx *lmcli.Context) *cobra.Command {
cmd := &cobra.Command{
Use: "chat [conversation]",
Short: "Open the chat interface",
Long: `Open the chat interface, optionally on a given conversation.`,
RunE: func(cmd *cobra.Command, args []string) error {
// TODO: implement jump-to-conversation logic
shortname := ""
if len(args) == 1 {
shortname = args[0]
}
err := tui.Launch(ctx, shortname)
if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err)
}
return nil
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
if len(args) != 0 {
return nil, compMode
}
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
},
}
return cmd
}

View File

@ -23,6 +23,7 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
chatCmd := ChatCmd(ctx)
continueCmd := ContinueCmd(ctx) continueCmd := ContinueCmd(ctx)
cloneCmd := CloneCmd(ctx) cloneCmd := CloneCmd(ctx)
editCmd := EditCmd(ctx) editCmd := EditCmd(ctx)
@ -48,6 +49,7 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
} }
root.AddCommand( root.AddCommand(
chatCmd,
cloneCmd, cloneCmd,
continueCmd, continueCmd,
editCmd, editCmd,

View File

@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Print(lastMessage.Content) fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message // Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages) continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err) return fmt.Errorf("error fetching LLM response: %v", err)
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
_, err := cmdutil.FetchAndShowCompletion(ctx, messages) _, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }

View File

@ -15,7 +15,7 @@ import (
// fetchAndShowCompletion prompts the LLM with the given messages and streams // fetchAndShowCompletion prompts the LLM with the given messages and streams
// the response to stdout. Returns all model reply messages. // the response to stdout. Returns all model reply messages.
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) { func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
content := make(chan string) // receives the reponse from LLM content := make(chan string) // receives the reponse from LLM
defer close(content) defer close(content)
@ -24,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model) completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return nil, err return "", err
} }
requestParams := model.RequestParameters{ requestParams := model.RequestParameters{
@ -34,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
ToolBag: ctx.EnabledTools, ToolBag: ctx.EnabledTools,
} }
var apiReplies []model.Message
response, err := completionProvider.CreateChatCompletionStream( response, err := completionProvider.CreateChatCompletionStream(
context.Background(), requestParams, messages, &apiReplies, content, context.Background(), requestParams, messages, callback, content,
) )
if response != "" { if response != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
@ -47,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
err = nil err = nil
} }
} }
return response, nil
return apiReplies, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
@ -99,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
replies, err := FetchAndShowCompletion(ctx, allMessages) replyCallback := func(reply model.Message) {
if err != nil { if !persist {
lmcli.Fatal("Error fetching LLM response: %v\n", err) return
}
reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
} }
if persist { _, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
for _, reply := range replies { if err != nil {
reply.ConversationID = c.ID lmcli.Fatal("Error fetching LLM response: %v\n", err)
err = ctx.Store.SaveMessage(&reply)
if err != nil {
lmcli.Warn("Could not save reply: %v\n", err)
}
}
} }
} }

View File

@ -133,7 +133,7 @@ func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -162,7 +162,9 @@ func (c *AnthropicClient) CreateChatCompletion(
default: default:
return "", fmt.Errorf("unsupported message type: %s", content.Type) return "", fmt.Errorf("unsupported message type: %s", content.Type)
} }
*replies = append(*replies, reply) if callback != nil {
callback(reply)
}
} }
return sb.String(), nil return sb.String(), nil
@ -172,7 +174,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
@ -291,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ToolResults: toolResults, ToolResults: toolResults,
} }
if replies != nil { if callback != nil {
*replies = append(append(*replies, toolCall), toolReply) callback(toolCall)
callback(toolReply)
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages // added to the original messages
messages = append(append(messages, toolCall), toolReply) messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(ctx, params, messages, replies, output) return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
reply := model.Message{ if callback != nil {
Role: model.MessageRoleAssistant, callback(model.Message{
Content: sb.String(), Role: model.MessageRoleAssistant,
Content: sb.String(),
})
} }
*replies = append(*replies, reply)
return sb.String(), nil return sb.String(), nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@ -9,6 +9,7 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
@ -160,7 +161,7 @@ func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callback provider.ReplyCallback,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
@ -177,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
if err != nil { if err != nil {
return "", err return "", err
} }
if results != nil { if callback != nil {
*replies = append(*replies, results...) for _, result := range results {
callback(result)
}
} }
// Recurse into CreateChatCompletion with the tool call replies // Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, replies) return c.CreateChatCompletion(ctx, params, messages, callback)
} }
if replies != nil { if callback != nil {
*replies = append(*replies, model.Message{ callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: choice.Message.Content, Content: choice.Message.Content,
}) })
@ -201,7 +204,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callbback provider.ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
@ -252,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err != nil { if err != nil {
return content.String(), err return content.String(), err
} }
if results != nil {
*replies = append(*replies, results...) if callbback != nil {
for _, result := range results {
callbback(result)
}
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, replies, output) return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
} }
if replies != nil { if callbback != nil {
*replies = append(*replies, model.Message{ callbback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) })

View File

@ -6,6 +6,8 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type ReplyCallback func(model.Message)
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
@ -14,7 +16,7 @@ type ChatCompletionClient interface {
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callback ReplyCallback,
) (string, error) ) (string, error)
// Like CreateChageCompletion, except the response is streamed via // Like CreateChageCompletion, except the response is streamed via
@ -23,7 +25,7 @@ type ChatCompletionClient interface {
ctx context.Context, ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
replies *[]model.Message, callback ReplyCallback,
output chan<- string, output chan<- string,
) (string, error) ) (string, error)
} }

354
pkg/tui/tui.go Normal file
View File

@ -0,0 +1,354 @@
package tui
// The terminal UI for lmcli, launched from the `lmcli chat` command
// TODO:
// - binding to open selected message/input in $EDITOR
import (
"context"
"fmt"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"github.com/charmbracelet/bubbles/textarea"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
type focusState int
const (
focusInput focusState = iota
focusMessages
)
type model struct {
ctx *lmcli.Context
convShortname string
// application state
conversation *models.Conversation
messages []models.Message
waitingForReply bool
replyChan chan string
replyCancelFunc context.CancelFunc
err error
// ui state
focus focusState
status string // a general status message
// ui elements
content viewport.Model
input textarea.Model
}
type message struct {
role string
content string
}
// custom tea.Msg types
type (
// sent on each chunk received from LLM
msgResponseChunk string
// sent when response is finished being received
msgResponseEnd string
// sent when a conversation is (re)loaded
msgConversationLoaded *models.Conversation
// send when a conversation's messages are laoded
msgMessagesLoaded []models.Message
// sent when an error occurs
msgError error
)
// styles
var (
inputStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#ff0000"))
contentStyle = lipgloss.NewStyle().PaddingLeft(2)
userStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("10"))
assistantStyle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12"))
footerStyle = lipgloss.NewStyle().
BorderTop(true).
BorderStyle(lipgloss.NormalBorder())
)
func (m model) Init() tea.Cmd {
return tea.Batch(
textarea.Blink,
m.loadConversation(m.convShortname),
waitForChunk(m.replyChan),
)
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c":
if m.waitingForReply {
m.replyCancelFunc()
} else {
return m, tea.Quit
}
case "q":
if m.focus != focusInput {
return m, tea.Quit
}
default:
var inputHandled tea.Cmd
switch m.focus {
case focusInput:
inputHandled = m.handleInputKey(msg)
case focusMessages:
inputHandled = m.handleMessagesKey(msg)
}
if inputHandled != nil {
return m, inputHandled
}
}
case tea.WindowSizeMsg:
m.content.Width = msg.Width
m.content.Height = msg.Height - m.input.Height() - lipgloss.Height(m.footerView())
m.input.SetWidth(msg.Width - 1)
m.updateContent()
case msgConversationLoaded:
c := (*models.Conversation)(msg)
cmd = m.loadMessages(c)
case msgMessagesLoaded:
m.messages = []models.Message(msg)
m.updateContent()
case msgResponseChunk:
chunk := string(msg)
if len(m.messages) > 0 {
i := len(m.messages) - 1
switch m.messages[i].Role {
case models.MessageRoleAssistant:
m.messages[i].Content += chunk
default:
m.messages = append(m.messages, models.Message{
Role: models.MessageRoleAssistant,
Content: chunk,
})
}
m.updateContent()
}
cmd = waitForChunk(m.replyChan) // wait for the next chunk
case msgResponseEnd:
m.replyCancelFunc = nil
m.waitingForReply = false
m.status = "Press ctrl+s to send"
}
if cmd != nil {
return m, cmd
}
m.input, cmd = m.input.Update(msg)
if cmd != nil {
return m, cmd
}
m.content, cmd = m.content.Update(msg)
if cmd != nil {
return m, cmd
}
return m, cmd
}
func (m model) View() string {
return lipgloss.JoinVertical(
lipgloss.Left,
m.content.View(),
m.inputView(),
m.footerView(),
)
}
func initialModel(ctx *lmcli.Context, convShortname string) model {
m := model{
ctx: ctx,
convShortname: convShortname,
replyChan: make(chan string),
}
m.content = viewport.New(0, 0)
m.input = textarea.New()
m.input.Placeholder = "Enter a message"
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
m.input.ShowLineNumbers = false
m.input.Focus()
m.updateContent()
m.waitingForReply = false
m.status = "Press ctrl+s to send"
return m
}
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "tab":
m.focus = focusInput
m.input.Focus()
}
return nil
}
func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.focus = focusMessages
m.input.Blur()
case "ctrl+s":
userInput := strings.TrimSpace(m.input.Value())
if strings.TrimSpace(userInput) == "" {
return nil
}
m.input.SetValue("")
m.messages = append(m.messages, models.Message{
Role: models.MessageRoleUser,
Content: userInput,
})
m.updateContent()
m.content.GotoBottom()
m.waitingForReply = true
m.status = "Waiting for response, press ctrl+c to cancel..."
return m.promptLLM()
}
return nil
}
func (m *model) loadConversation(shortname string) tea.Cmd {
return func() tea.Msg {
if shortname == "" {
return nil
}
c, err := m.ctx.Store.ConversationByShortName(shortname)
if err != nil {
return msgError(fmt.Errorf("Could not lookup conversation: %v\n", err))
}
if c.ID == 0 {
return msgError(fmt.Errorf("Conversation not found with short name: %s\n", shortname))
}
return msgConversationLoaded(c)
}
}
func (m *model) loadMessages(c *models.Conversation) tea.Cmd {
return func() tea.Msg {
messages, err := m.ctx.Store.Messages(c)
if err != nil {
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
}
return msgMessagesLoaded(messages)
}
}
func waitForChunk(ch chan string) tea.Cmd {
return func() tea.Msg {
return msgResponseChunk(<-ch)
}
}
func (m *model) promptLLM() tea.Cmd {
return func() tea.Msg {
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
if err != nil {
return msgError(err)
}
var toolBag []models.Tool
for _, toolName := range *m.ctx.Config.Tools.EnabledTools {
tool, ok := tools.AvailableTools[toolName]
if ok {
toolBag = append(toolBag, tool)
}
}
requestParams := models.RequestParameters{
Model: *m.ctx.Config.Defaults.Model,
MaxTokens: *m.ctx.Config.Defaults.MaxTokens,
Temperature: *m.ctx.Config.Defaults.Temperature,
ToolBag: toolBag,
}
ctx, replyCancelFunc := context.WithCancel(context.Background())
m.replyCancelFunc = replyCancelFunc
// TODO: supply a reply callback and handle error
resp, _ := completionProvider.CreateChatCompletionStream(
ctx, requestParams, m.messages, nil, m.replyChan,
)
return msgResponseEnd(resp)
}
}
func (m *model) updateContent() {
sb := strings.Builder{}
msgCnt := len(m.messages)
for i, message := range m.messages {
var style lipgloss.Style
if message.Role == models.MessageRoleUser {
style = userStyle
} else {
style = assistantStyle
}
sb.WriteString(fmt.Sprintf("%s:\n\n", style.Render(string(message.Role.FriendlyRole()))))
highlighted, _ := m.ctx.Chroma.HighlightS(message.Content)
sb.WriteString(contentStyle.Width(m.content.Width - 5).Render(highlighted))
if i < msgCnt-1 {
sb.WriteString("\n\n")
}
}
m.content.SetContent(sb.String())
}
func (m model) inputView() string {
var inputView string
if m.waitingForReply {
inputView = inputStyle.Faint(true).Render(m.input.View())
} else {
inputView = inputStyle.Render(m.input.View())
}
return inputView
}
func (m model) footerView() string {
left := m.status
right := fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)
totalWidth := lipgloss.Width(left + right)
var padding string
if m.content.Width-totalWidth > 0 {
padding = strings.Repeat(" ", m.content.Width-totalWidth)
} else {
padding = ""
}
footer := lipgloss.JoinHorizontal(lipgloss.Center, left, padding, right)
return footerStyle.Width(m.content.Width).Render(footer)
}
func Launch(ctx *lmcli.Context, convShortname string) error {
p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen())
if _, err := p.Run(); err != nil {
return fmt.Errorf("Error running program: %v", err)
}
return nil
}