tui: fixed response cancelling

This commit is contained in:
Matt Low 2024-03-15 06:44:42 +00:00
parent 6242ea17d8
commit e9fde37201

View File

@ -44,9 +44,9 @@ type model struct {
conversation *models.Conversation conversation *models.Conversation
messages []models.Message messages []models.Message
waitingForReply bool waitingForReply bool
stopSignal chan interface{}
replyChan chan models.Message replyChan chan models.Message
replyChunkChan chan string replyChunkChan chan string
replyCancelFunc context.CancelFunc
err error err error
persistence bool // whether we will save new messages in the conversation persistence bool // whether we will save new messages in the conversation
@ -124,7 +124,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg.String() { switch msg.String() {
case "ctrl+c": case "ctrl+c":
if m.waitingForReply { if m.waitingForReply {
m.replyCancelFunc() m.stopSignal <- "stahp!"
} else { } else {
return m, tea.Quit return m, tea.Quit
} }
@ -204,11 +204,9 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.updateContent() m.updateContent()
cmds = append(cmds, m.waitForReply()) cmds = append(cmds, m.waitForReply())
case msgResponseEnd: case msgResponseEnd:
m.replyCancelFunc = nil
m.waitingForReply = false m.waitingForReply = false
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
case msgResponseError: case msgResponseError:
m.replyCancelFunc = nil
m.waitingForReply = false m.waitingForReply = false
m.status = "Press ctrl+s to send" m.status = "Press ctrl+s to send"
m.err = error(msg) m.err = error(msg)
@ -375,6 +373,7 @@ func initialModel(ctx *lmcli.Context, convShortname string) model {
conversation: &models.Conversation{}, conversation: &models.Conversation{},
persistence: true, persistence: true,
stopSignal: make(chan interface{}),
replyChan: make(chan models.Message), replyChan: make(chan models.Message),
replyChunkChan: make(chan string), replyChunkChan: make(chan string),
} }
@ -551,16 +550,25 @@ func (m *model) promptLLM() tea.Cmd {
m.replyChan <- msg m.replyChan <- msg
} }
ctx, replyCancelFunc := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
m.replyCancelFunc = replyCancelFunc
canceled := false
go func() {
select {
case <-m.stopSignal:
canceled = true
cancel()
}
}()
resp, err := completionProvider.CreateChatCompletionStream( resp, err := completionProvider.CreateChatCompletionStream(
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan, ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
) )
if err != nil { if err != nil && !canceled {
return msgResponseError(err) return msgResponseError(err)
} }
return msgResponseEnd(resp) return msgResponseEnd(resp)
} }
} }