Refactor streamed response handling

Update CreateChangeCompletionStream to return the entire response upon
stream completion. Renamed HandleDelayedResponse to
HandleDelayedContent, which no longer returns the content.

Removes the need wrapping HandleDelayedContent in an immediately invoked
function and the passing of the completed response over a channel. Also
allows us to better handle the case of partial a response.
This commit is contained in:
Matt Low 2023-11-24 03:45:43 +00:00
parent 303c4193cb
commit 6249fbc8f8
3 changed files with 66 additions and 57 deletions

View File

@ -53,6 +53,21 @@ func SystemPrompt() string {
return systemPrompt
}
// HandlePartialResponse accepts a response and an err. If err is nil, it does
// nothing and returns nil. If response != "" and err != nil, it prints a
// warning and returns nil. If response == "" and err != nil, it returns the
// error.
func HandlePartialResponse(response string, err error) (e error) {
if err != nil {
if response != "" {
Warn("Received partial response. Error: %v\n", err)
} else {
e = err
}
}
return
}
// InputFromArgsOrEditor returns either the provided input from the args slice
// (joined with spaces), or if len(args) is 0, opens an editor and returns
// whatever input was provided there. placeholder is a string which populates
@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{
assistantReply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleDelayedResponse(receiver)
}()
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil {
Fatal("%v\n", err)
Fatal("Error while receiving response: %v\n", err)
}
assistantReply.OriginalContent = <-response
assistantReply.OriginalContent = response
fmt.Println()
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
fmt.Println()
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
@ -338,7 +350,6 @@ var newCmd = &cobra.Command{
Fatal("No message was provided.\n")
}
// TODO: set title if --title provided, otherwise defer for later(?)
conversation := Conversation{}
err := store.SaveConversation(&conversation)
if err != nil {
@ -372,25 +383,22 @@ var newCmd = &cobra.Command{
reply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleDelayedResponse(receiver)
}()
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil {
Fatal("%v\n", err)
Fatal("Error while receiving response: %v\n", err)
}
reply.OriginalContent = <-response
fmt.Println()
reply.OriginalContent = response
err = store.SaveMessage(&reply)
if err != nil {
Fatal("Could not save reply: %v\n", err)
}
fmt.Println()
err = conversation.GenerateTitle()
if err != nil {
Warn("Could not generate title for conversation: %v\n", err)
@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{
}
receiver := make(chan string)
go HandleDelayedResponse(receiver)
err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil {
Fatal("%v\n", err)
}
@ -459,7 +469,7 @@ var retryCmd = &cobra.Command{
}
var lastUserMessageIndex int
for i := len(messages) - 1; i >=0; i-- {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
lastUserMessageIndex = i
break
@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{
assistantReply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleStreamedResponse(receiver)
}()
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil {
Fatal("%v\n", err)
Fatal("Error while receiving response: %v\n", err)
}
assistantReply.OriginalContent = <-response
fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{
},
}
var continueCmd = &cobra.Command{
Use: "continue <conversation>",
Short: "Continues where the previous prompt left off.",
@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{
assistantReply.RenderTTY()
receiver := make(chan string)
response := make(chan string)
go func() {
response <- HandleStreamedResponse(receiver)
}()
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil {
Fatal("%v\n", err)
Fatal("Error while receiving response: %v\n", err)
}
assistantReply.OriginalContent = <-response
fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"io"
"strings"
openai "github.com/sashabaranov/go-openai"
)
@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
}
// CreateChatCompletionStream submits a streaming Chat Completion API request
// and streams the response to the provided output channel.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
// and both returns and streams the response to the provided output channel.
// May return a partial response if an error occurs mid-stream.
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil {
return err
return "", err
}
defer stream.Close()
sb := strings.Builder{}
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
return nil
response, e := stream.Recv()
if errors.Is(e, io.EOF) {
break
}
if err != nil {
return err
if e != nil {
err = e
break
}
output <- response.Choices[0].Delta.Content
chunk := response.Choices[0].Delta.Content
output <- chunk
sb.WriteString(chunk)
}
return sb.String(), err
}

View File

@ -3,7 +3,6 @@ package cli
import (
"fmt"
"os"
"strings"
"time"
"github.com/alecthomas/chroma/v2/quick"
@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) {
}
}
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
// (possibly chunked) content received on the response channel to stdout.
// HandleDelayedContent displays a waiting animation to stdout while waiting
// for content to be received on the provided channel. As soon as any (possibly
// chunked) content is received on the channel, the waiting animation is
// replaced by the content.
// Blocks until the channel is closed.
func HandleDelayedResponse(response chan string) string {
func HandleDelayedContent(content <-chan string) {
waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal)
sb := strings.Builder{}
firstChunk := true
for chunk := range response {
for chunk := range content {
if firstChunk {
// notify wait animation that we've received data
waitSignal <- ""
@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string {
firstChunk = false
}
fmt.Print(chunk)
sb.WriteString(chunk)
}
return sb.String()
}
// RenderConversation renders the given messages to TTY, with optional space