Refactor streamed response handling
Update CreateChangeCompletionStream to return the entire response upon stream completion. Renamed HandleDelayedResponse to HandleDelayedContent, which no longer returns the content. Removes the need wrapping HandleDelayedContent in an immediately invoked function and the passing of the completed response over a channel. Also allows us to better handle the case of partial a response.
This commit is contained in:
parent
303c4193cb
commit
6249fbc8f8
@ -53,6 +53,21 @@ func SystemPrompt() string {
|
||||
return systemPrompt
|
||||
}
|
||||
|
||||
// HandlePartialResponse accepts a response and an err. If err is nil, it does
|
||||
// nothing and returns nil. If response != "" and err != nil, it prints a
|
||||
// warning and returns nil. If response == "" and err != nil, it returns the
|
||||
// error.
|
||||
func HandlePartialResponse(response string, err error) (e error) {
|
||||
if err != nil {
|
||||
if response != "" {
|
||||
Warn("Received partial response. Error: %v\n", err)
|
||||
} else {
|
||||
e = err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
go func() {
|
||||
response <- HandleDelayedResponse(receiver)
|
||||
}()
|
||||
go HandleDelayedContent(receiver)
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
err = HandlePartialResponse(response, err)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
Fatal("Error while receiving response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = <-response
|
||||
assistantReply.OriginalContent = response
|
||||
fmt.Println()
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
@ -338,7 +350,6 @@ var newCmd = &cobra.Command{
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
// TODO: set title if --title provided, otherwise defer for later(?)
|
||||
conversation := Conversation{}
|
||||
err := store.SaveConversation(&conversation)
|
||||
if err != nil {
|
||||
@ -372,25 +383,22 @@ var newCmd = &cobra.Command{
|
||||
reply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
go func() {
|
||||
response <- HandleDelayedResponse(receiver)
|
||||
}()
|
||||
go HandleDelayedContent(receiver)
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
err = HandlePartialResponse(response, err)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
Fatal("Error while receiving response: %v\n", err)
|
||||
}
|
||||
|
||||
reply.OriginalContent = <-response
|
||||
fmt.Println()
|
||||
reply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Fatal("Could not save reply: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
|
||||
err = conversation.GenerateTitle()
|
||||
if err != nil {
|
||||
Warn("Could not generate title for conversation: %v\n", err)
|
||||
@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
go HandleDelayedContent(receiver)
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
|
||||
err = HandlePartialResponse(response, err)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
@ -459,7 +469,7 @@ var retryCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
var lastUserMessageIndex int
|
||||
for i := len(messages) - 1; i >=0; i-- {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
lastUserMessageIndex = i
|
||||
break
|
||||
@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
go func() {
|
||||
response <- HandleStreamedResponse(receiver)
|
||||
}()
|
||||
go HandleDelayedContent(receiver)
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
err = HandlePartialResponse(response, err)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
Fatal("Error while receiving response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = <-response
|
||||
fmt.Println()
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
var continueCmd = &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continues where the previous prompt left off.",
|
||||
@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
go func() {
|
||||
response <- HandleStreamedResponse(receiver)
|
||||
}()
|
||||
go HandleDelayedContent(receiver)
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
err = HandlePartialResponse(response, err)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
Fatal("Error while receiving response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = <-response
|
||||
fmt.Println()
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
// and streams the response to the provided output channel.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
|
||||
// and both returns and streams the response to the provided output channel.
|
||||
// May return a partial response if an error occurs mid-stream.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
|
||||
@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
|
||||
sb := strings.Builder{}
|
||||
for {
|
||||
response, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
output <- response.Choices[0].Delta.Content
|
||||
chunk := response.Choices[0].Delta.Content
|
||||
output <- chunk
|
||||
sb.WriteString(chunk)
|
||||
}
|
||||
return sb.String(), err
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package cli
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alecthomas/chroma/v2/quick"
|
||||
@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) {
|
||||
}
|
||||
}
|
||||
|
||||
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
|
||||
// (possibly chunked) content received on the response channel to stdout.
|
||||
// HandleDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func HandleDelayedResponse(response chan string) string {
|
||||
func HandleDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
sb := strings.Builder{}
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range response {
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string {
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
sb.WriteString(chunk)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
|
Loading…
Reference in New Issue
Block a user