Refactor streamed response handling

Update CreateChangeCompletionStream to return the entire response upon
stream completion. Renamed HandleDelayedResponse to
HandleDelayedContent, which no longer returns the content.

Removes the need wrapping HandleDelayedContent in an immediately invoked
function and the passing of the completed response over a channel. Also
allows us to better handle the case of partial a response.
This commit is contained in:
Matt Low 2023-11-24 03:45:43 +00:00
parent 303c4193cb
commit 6249fbc8f8
3 changed files with 66 additions and 57 deletions

View File

@ -53,6 +53,21 @@ func SystemPrompt() string {
return systemPrompt return systemPrompt
} }
// HandlePartialResponse accepts a response and an err. If err is nil, it does
// nothing and returns nil. If response != "" and err != nil, it prints a
// warning and returns nil. If response == "" and err != nil, it returns the
// error.
func HandlePartialResponse(response string, err error) (e error) {
if err != nil {
if response != "" {
Warn("Received partial response. Error: %v\n", err)
} else {
e = err
}
}
return
}
// InputFromArgsOrEditor returns either the provided input from the args slice // InputFromArgsOrEditor returns either the provided input from the args slice
// (joined with spaces), or if len(args) is 0, opens an editor and returns // (joined with spaces), or if len(args) is 0, opens an editor and returns
// whatever input was provided there. placeholder is a string which populates // whatever input was provided there. placeholder is a string which populates
@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) go HandleDelayedContent(receiver)
go func() { response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
response <- HandleDelayedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver) err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("Error while receiving response: %v\n", err)
} }
assistantReply.OriginalContent = <-response assistantReply.OriginalContent = response
fmt.Println()
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Fatal("Could not save assistant reply: %v\n", err)
} }
fmt.Println()
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -338,7 +350,6 @@ var newCmd = &cobra.Command{
Fatal("No message was provided.\n") Fatal("No message was provided.\n")
} }
// TODO: set title if --title provided, otherwise defer for later(?)
conversation := Conversation{} conversation := Conversation{}
err := store.SaveConversation(&conversation) err := store.SaveConversation(&conversation)
if err != nil { if err != nil {
@ -372,25 +383,22 @@ var newCmd = &cobra.Command{
reply.RenderTTY() reply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) go HandleDelayedContent(receiver)
go func() { response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
response <- HandleDelayedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver) err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("Error while receiving response: %v\n", err)
} }
reply.OriginalContent = <-response fmt.Println()
reply.OriginalContent = response
err = store.SaveMessage(&reply) err = store.SaveMessage(&reply)
if err != nil { if err != nil {
Fatal("Could not save reply: %v\n", err) Fatal("Could not save reply: %v\n", err)
} }
fmt.Println()
err = conversation.GenerateTitle() err = conversation.GenerateTitle()
if err != nil { if err != nil {
Warn("Could not generate title for conversation: %v\n", err) Warn("Could not generate title for conversation: %v\n", err)
@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{
} }
receiver := make(chan string) receiver := make(chan string)
go HandleDelayedResponse(receiver) go HandleDelayedContent(receiver)
err := CreateChatCompletionStream(model, messages, maxTokens, receiver) response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
} }
@ -459,7 +469,7 @@ var retryCmd = &cobra.Command{
} }
var lastUserMessageIndex int var lastUserMessageIndex int
for i := len(messages) - 1; i >=0; i-- { for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" { if messages[i].Role == "user" {
lastUserMessageIndex = i lastUserMessageIndex = i
break break
@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) go HandleDelayedContent(receiver)
go func() { response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
response <- HandleStreamedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver) err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("Error while receiving response: %v\n", err)
} }
assistantReply.OriginalContent = <-response fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {
@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{
}, },
} }
var continueCmd = &cobra.Command{ var continueCmd = &cobra.Command{
Use: "continue <conversation>", Use: "continue <conversation>",
Short: "Continues where the previous prompt left off.", Short: "Continues where the previous prompt left off.",
@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) go HandleDelayedContent(receiver)
go func() { response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
response <- HandleStreamedResponse(receiver)
}()
err = CreateChatCompletionStream(model, messages, maxTokens, receiver) err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("Error while receiving response: %v\n", err)
} }
assistantReply.OriginalContent = <-response fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"io" "io"
"strings"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
} }
// CreateChatCompletionStream submits a streaming Chat Completion API request // CreateChatCompletionStream submits a streaming Chat Completion API request
// and streams the response to the provided output channel. // and both returns and streams the response to the provided output channel.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error { // May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
stream, err := client.CreateChatCompletionStream(context.Background(), req) stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil { if err != nil {
return err return "", err
} }
defer stream.Close() defer stream.Close()
sb := strings.Builder{}
for { for {
response, err := stream.Recv() response, e := stream.Recv()
if errors.Is(err, io.EOF) { if errors.Is(e, io.EOF) {
return nil break
} }
if err != nil { if e != nil {
return err err = e
break
} }
output <- response.Choices[0].Delta.Content chunk := response.Choices[0].Delta.Content
output <- chunk
sb.WriteString(chunk)
} }
return sb.String(), err
} }

View File

@ -3,7 +3,6 @@ package cli
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"time" "time"
"github.com/alecthomas/chroma/v2/quick" "github.com/alecthomas/chroma/v2/quick"
@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) {
} }
} }
// HandledDelayedResponse writes a waiting animation (abusing \r) and the // HandleDelayedContent displays a waiting animation to stdout while waiting
// (possibly chunked) content received on the response channel to stdout. // for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
// Blocks until the channel is closed. // Blocks until the channel is closed.
func HandleDelayedResponse(response chan string) string { func HandleDelayedContent(content <-chan string) {
waitSignal := make(chan any) waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal) go ShowWaitAnimation(waitSignal)
sb := strings.Builder{}
firstChunk := true firstChunk := true
for chunk := range response { for chunk := range content {
if firstChunk { if firstChunk {
// notify wait animation that we've received data // notify wait animation that we've received data
waitSignal <- "" waitSignal <- ""
@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string {
firstChunk = false firstChunk = false
} }
fmt.Print(chunk) fmt.Print(chunk)
sb.WriteString(chunk)
} }
return sb.String()
} }
// RenderConversation renders the given messages to TTY, with optional space // RenderConversation renders the given messages to TTY, with optional space