diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 41ea536..2494cfe 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -53,6 +53,21 @@ func SystemPrompt() string { return systemPrompt } +// HandlePartialResponse accepts a response and an err. If err is nil, it does +// nothing and returns nil. If response != "" and err != nil, it prints a +// warning and returns nil. If response == "" and err != nil, it returns the +// error. +func HandlePartialResponse(response string, err error) (e error) { + if err != nil { + if response != "" { + Warn("Received partial response. Error: %v\n", err) + } else { + e = err + } + } + return +} + // InputFromArgsOrEditor returns either the provided input from the args slice // (joined with spaces), or if len(args) is 0, opens an editor and returns // whatever input was provided there. placeholder is a string which populates @@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{ assistantReply.RenderTTY() receiver := make(chan string) - response := make(chan string) - go func() { - response <- HandleDelayedResponse(receiver) - }() + go HandleDelayedContent(receiver) + response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) - err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + err = HandlePartialResponse(response, err) if err != nil { - Fatal("%v\n", err) + Fatal("Error while receiving response: %v\n", err) } - assistantReply.OriginalContent = <-response + assistantReply.OriginalContent = response + fmt.Println() err = store.SaveMessage(&assistantReply) if err != nil { Fatal("Could not save assistant reply: %v\n", err) } - - fmt.Println() }, ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { compMode := cobra.ShellCompDirectiveNoFileComp @@ -338,7 +350,6 @@ var newCmd = &cobra.Command{ Fatal("No message was provided.\n") } - // TODO: set title if --title provided, otherwise defer for later(?) conversation := Conversation{} err := store.SaveConversation(&conversation) if err != nil { @@ -372,25 +383,22 @@ var newCmd = &cobra.Command{ reply.RenderTTY() receiver := make(chan string) - response := make(chan string) - go func() { - response <- HandleDelayedResponse(receiver) - }() + go HandleDelayedContent(receiver) + response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) - err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + err = HandlePartialResponse(response, err) if err != nil { - Fatal("%v\n", err) + Fatal("Error while receiving response: %v\n", err) } - reply.OriginalContent = <-response + fmt.Println() + reply.OriginalContent = response err = store.SaveMessage(&reply) if err != nil { Fatal("Could not save reply: %v\n", err) } - fmt.Println() - err = conversation.GenerateTitle() if err != nil { Warn("Could not generate title for conversation: %v\n", err) @@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{ } receiver := make(chan string) - go HandleDelayedResponse(receiver) - err := CreateChatCompletionStream(model, messages, maxTokens, receiver) + go HandleDelayedContent(receiver) + response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) + + err = HandlePartialResponse(response, err) if err != nil { Fatal("%v\n", err) } @@ -459,7 +469,7 @@ var retryCmd = &cobra.Command{ } var lastUserMessageIndex int - for i := len(messages) - 1; i >=0; i-- { + for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { lastUserMessageIndex = i break @@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{ assistantReply.RenderTTY() receiver := make(chan string) - response := make(chan string) - go func() { - response <- HandleStreamedResponse(receiver) - }() + go HandleDelayedContent(receiver) + response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) - err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + err = HandlePartialResponse(response, err) if err != nil { - Fatal("%v\n", err) + Fatal("Error while receiving response: %v\n", err) } - assistantReply.OriginalContent = <-response + fmt.Println() + assistantReply.OriginalContent = response err = store.SaveMessage(&assistantReply) if err != nil { @@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{ }, } - var continueCmd = &cobra.Command{ Use: "continue ", Short: "Continues where the previous prompt left off.", @@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{ assistantReply.RenderTTY() receiver := make(chan string) - response := make(chan string) - go func() { - response <- HandleStreamedResponse(receiver) - }() + go HandleDelayedContent(receiver) + response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver) - err = CreateChatCompletionStream(model, messages, maxTokens, receiver) + err = HandlePartialResponse(response, err) if err != nil { - Fatal("%v\n", err) + Fatal("Error while receiving response: %v\n", err) } - assistantReply.OriginalContent = <-response + fmt.Println() + assistantReply.OriginalContent = response err = store.SaveMessage(&assistantReply) if err != nil { diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go index 57f1407..3ef2100 100644 --- a/pkg/cli/openai.go +++ b/pkg/cli/openai.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "strings" openai "github.com/sashabaranov/go-openai" ) @@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri } // CreateChatCompletionStream submits a streaming Chat Completion API request -// and streams the response to the provided output channel. -func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error { +// and both returns and streams the response to the provided output channel. +// May return a partial response if an error occurs mid-stream. +func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) { client := openai.NewClient(*config.OpenAI.APIKey) req := CreateChatCompletionRequest(model, messages, maxTokens) @@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int, stream, err := client.CreateChatCompletionStream(context.Background(), req) if err != nil { - return err + return "", err } - defer stream.Close() + sb := strings.Builder{} for { - response, err := stream.Recv() - if errors.Is(err, io.EOF) { - return nil + response, e := stream.Recv() + if errors.Is(e, io.EOF) { + break } - if err != nil { - return err + if e != nil { + err = e + break } - output <- response.Choices[0].Delta.Content + chunk := response.Choices[0].Delta.Content + output <- chunk + sb.WriteString(chunk) } + return sb.String(), err } diff --git a/pkg/cli/tty.go b/pkg/cli/tty.go index ad0ed9b..4ab6f38 100644 --- a/pkg/cli/tty.go +++ b/pkg/cli/tty.go @@ -3,7 +3,6 @@ package cli import ( "fmt" "os" - "strings" "time" "github.com/alecthomas/chroma/v2/quick" @@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) { } } -// HandledDelayedResponse writes a waiting animation (abusing \r) and the -// (possibly chunked) content received on the response channel to stdout. +// HandleDelayedContent displays a waiting animation to stdout while waiting +// for content to be received on the provided channel. As soon as any (possibly +// chunked) content is received on the channel, the waiting animation is +// replaced by the content. // Blocks until the channel is closed. -func HandleDelayedResponse(response chan string) string { +func HandleDelayedContent(content <-chan string) { waitSignal := make(chan any) go ShowWaitAnimation(waitSignal) - sb := strings.Builder{} - firstChunk := true - for chunk := range response { + for chunk := range content { if firstChunk { // notify wait animation that we've received data waitSignal <- "" @@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string { firstChunk = false } fmt.Print(chunk) - sb.WriteString(chunk) } - - return sb.String() } // RenderConversation renders the given messages to TTY, with optional space