From 6249fbc8f8c4527fc7d64dfea2ec746e573339a7 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Fri, 24 Nov 2023 03:45:43 +0000
Subject: [PATCH] Refactor streamed response handling

Update CreateChatCompletionStream to return the entire response upon
stream completion.

Rename HandleDelayedResponse to HandleDelayedContent, which no longer
returns the content. This removes the need to wrap HandleDelayedContent
in an immediately invoked function and to pass the completed response
over a channel. It also allows us to better handle the case of a
partial response.
---
 pkg/cli/cmd.go    | 81 +++++++++++++++++++++++++----------------------
 pkg/cli/openai.go | 26 +++++++++------
 pkg/cli/tty.go    | 16 ++++------
 3 files changed, 66 insertions(+), 57 deletions(-)

diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go
index 41ea536..2494cfe 100644
--- a/pkg/cli/cmd.go
+++ b/pkg/cli/cmd.go
@@ -53,6 +53,21 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
+// HandlePartialResponse accepts a response and an err. If err is nil, it does
+// nothing and returns nil. If response != "" and err != nil, it prints a
+// warning and returns nil. If response == "" and err != nil, it returns the
+// error.
+func HandlePartialResponse(response string, err error) (e error) {
+	if err != nil {
+		if response != "" {
+			Warn("Received partial response. Error: %v\n", err)
+		} else {
+			e = err
+		}
+	}
+	return
+}
+
 // InputFromArgsOrEditor returns either the provided input from the args slice
 // (joined with spaces), or if len(args) is 0, opens an editor and returns
 // whatever input was provided there. placeholder is a string which populates
@@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{
 		assistantReply.RenderTTY()
 
 		receiver := make(chan string)
-		response := make(chan string)
-		go func() {
-			response <- HandleDelayedResponse(receiver)
-		}()
+		go HandleDelayedContent(receiver)
+		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
 
-		err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
+		err = HandlePartialResponse(response, err)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error while receiving response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = <-response
+		assistantReply.OriginalContent = response
+		fmt.Println()
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -338,7 +350,6 @@ var newCmd = &cobra.Command{
 			Fatal("No message was provided.\n")
 		}
 
-		// TODO: set title if --title provided, otherwise defer for later(?)
 		conversation := Conversation{}
 		err := store.SaveConversation(&conversation)
 		if err != nil {
@@ -372,25 +383,22 @@ var newCmd = &cobra.Command{
 		reply.RenderTTY()
 
 		receiver := make(chan string)
-		response := make(chan string)
-		go func() {
-			response <- HandleDelayedResponse(receiver)
-		}()
+		go HandleDelayedContent(receiver)
+		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
 
-		err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
+		err = HandlePartialResponse(response, err)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error while receiving response: %v\n", err)
 		}
 
-		reply.OriginalContent = <-response
+		fmt.Println()
+		reply.OriginalContent = response
 
 		err = store.SaveMessage(&reply)
 		if err != nil {
 			Fatal("Could not save reply: %v\n", err)
 		}
 
-		fmt.Println()
-
 		err = conversation.GenerateTitle()
 		if err != nil {
 			Warn("Could not generate title for conversation: %v\n", err)
@@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{
 		}
 
 		receiver := make(chan string)
-		go HandleDelayedResponse(receiver)
-		err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+		go HandleDelayedContent(receiver)
+		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+
+		err = HandlePartialResponse(response, err)
 		if err != nil {
 			Fatal("%v\n", err)
 		}
@@ -459,7 +469,7 @@ var retryCmd = &cobra.Command{
 		}
 
 		var lastUserMessageIndex int
-		for i := len(messages) - 1; i >=0; i-- {
+		for i := len(messages) - 1; i >= 0; i-- {
 			if messages[i].Role == "user" {
 				lastUserMessageIndex = i
 				break
@@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{
 		assistantReply.RenderTTY()
 
 		receiver := make(chan string)
-		response := make(chan string)
-		go func() {
-			response <- HandleStreamedResponse(receiver)
-		}()
+		go HandleDelayedContent(receiver)
+		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
 
-		err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
+		err = HandlePartialResponse(response, err)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error while receiving response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = <-response
+		fmt.Println()
+		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
@@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{
 	},
 }
 
-
 var continueCmd = &cobra.Command{
 	Use:   "continue ",
 	Short: "Continues where the previous prompt left off.",
@@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{
 		assistantReply.RenderTTY()
 
 		receiver := make(chan string)
-		response := make(chan string)
-		go func() {
-			response <- HandleStreamedResponse(receiver)
-		}()
+		go HandleDelayedContent(receiver)
+		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
 
-		err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
+		err = HandlePartialResponse(response, err)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error while receiving response: %v\n", err)
 		}
 
-		assistantReply.OriginalContent = <-response
+		fmt.Println()
+		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go
index 57f1407..3ef2100 100644
--- a/pkg/cli/openai.go
+++ b/pkg/cli/openai.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"io"
+	"strings"
 
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
 }
 
 // CreateChatCompletionStream submits a streaming Chat Completion API request
-// and streams the response to the provided output channel.
-func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
+// and both returns and streams the response to the provided output channel.
+// May return a partial response if an error occurs mid-stream.
+func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
 	client := openai.NewClient(*config.OpenAI.APIKey)
 
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
@@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 
 	stream, err := client.CreateChatCompletionStream(context.Background(), req)
 	if err != nil {
-		return err
+		return "", err
 	}
-
 	defer stream.Close()
 
+	sb := strings.Builder{}
 	for {
-		response, err := stream.Recv()
-		if errors.Is(err, io.EOF) {
-			return nil
+		response, e := stream.Recv()
+		if errors.Is(e, io.EOF) {
+			break
 		}
-		if err != nil {
-			return err
+		if e != nil {
+			err = e
+			break
 		}
 
-		output <- response.Choices[0].Delta.Content
+		chunk := response.Choices[0].Delta.Content
+		output <- chunk
+		sb.WriteString(chunk)
 	}
+	return sb.String(), err
 }
diff --git a/pkg/cli/tty.go b/pkg/cli/tty.go
index ad0ed9b..4ab6f38 100644
--- a/pkg/cli/tty.go
+++ b/pkg/cli/tty.go
@@ -3,7 +3,6 @@ package cli
 import (
 	"fmt"
 	"os"
-	"strings"
 	"time"
 
 	"github.com/alecthomas/chroma/v2/quick"
@@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) {
 	}
 }
 
-// HandledDelayedResponse writes a waiting animation (abusing \r) and the
-// (possibly chunked) content received on the response channel to stdout.
+// HandleDelayedContent displays a waiting animation to stdout while waiting
+// for content to be received on the provided channel. As soon as any (possibly
+// chunked) content is received on the channel, the waiting animation is
+// replaced by the content.
 // Blocks until the channel is closed.
-func HandleDelayedResponse(response chan string) string {
+func HandleDelayedContent(content <-chan string) {
 	waitSignal := make(chan any)
 	go ShowWaitAnimation(waitSignal)
 
-	sb := strings.Builder{}
-
 	firstChunk := true
-	for chunk := range response {
+	for chunk := range content {
 		if firstChunk {
 			// notify wait animation that we've received data
 			waitSignal <- ""
@@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string {
 			firstChunk = false
 		}
 		fmt.Print(chunk)
-		sb.WriteString(chunk)
 	}
-
-	return sb.String()
 }
 
 // RenderConversation renders the given messages to TTY, with optional space
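For reference, the sketch below illustrates the producer/consumer pattern this patch moves to: the streaming function forwards each chunk to a display channel while also accumulating the chunks into the string it returns, and a separate helper downgrades a mid-stream error to a warning whenever a partial response was received. It is a minimal, self-contained approximation, not code from the repository: streamCompletion, handleContent, handlePartial, and the simulated chunk source are hypothetical stand-ins for CreateChatCompletionStream, HandleDelayedContent, and HandlePartialResponse, and the done channel is added only so the standalone program can wait for the display goroutine before printing its summary.

package main

import (
	"errors"
	"fmt"
	"strings"
)

// streamCompletion mimics the shape of the refactored streaming call: it
// forwards every chunk to the output channel for live display while also
// accumulating the chunks, then returns the full (possibly partial) response
// together with any error hit mid-stream. failAfter simulates a stream error
// after that many chunks; pass -1 to disable it.
func streamCompletion(chunks []string, failAfter int, output chan<- string) (string, error) {
	defer close(output) // let the consumer's range loop terminate

	sb := strings.Builder{}
	var err error
	for i, chunk := range chunks {
		if failAfter >= 0 && i == failAfter {
			err = errors.New("stream interrupted") // simulated mid-stream failure
			break
		}
		output <- chunk
		sb.WriteString(chunk)
	}
	return sb.String(), err
}

// handleContent mirrors HandleDelayedContent's contract: it only displays
// chunks and blocks until the channel is closed, returning nothing.
func handleContent(content <-chan string, done chan<- struct{}) {
	for chunk := range content {
		fmt.Print(chunk)
	}
	close(done)
}

// handlePartial mirrors HandlePartialResponse: a non-empty partial response
// downgrades the error to a warning so the partial content can still be saved.
func handlePartial(response string, err error) error {
	if err == nil {
		return nil
	}
	if response != "" {
		fmt.Printf("\nWarning: received partial response: %v\n", err)
		return nil
	}
	return err
}

func main() {
	receiver := make(chan string)
	done := make(chan struct{})
	go handleContent(receiver, done)

	// Interrupt after three chunks to exercise the partial-response path.
	chunks := []string{"Hello", ", ", "stream", "ing ", "world"}
	response, err := streamCompletion(chunks, 3, receiver)
	<-done // wait for the display goroutine before printing anything else

	if err := handlePartial(response, err); err != nil {
		fmt.Println("fatal:", err)
		return
	}
	fmt.Printf("saved content: %q\n", response)
}

Returning the accumulated string from the producer, rather than sending it back over a second channel, is what lets the caller save whatever content arrived even when the stream errors out partway.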