Refactor streamed response handling
Update CreateChatCompletionStream to return the entire response upon stream completion. Renamed HandleDelayedResponse to HandleDelayedContent, which no longer returns the content. This removes the need for wrapping HandleDelayedContent in an immediately invoked function and for passing the completed response over a channel. It also allows us to better handle the case of a partial response.
This commit is contained in:
parent
303c4193cb
commit
6249fbc8f8
@ -53,6 +53,21 @@ func SystemPrompt() string {
|
|||||||
return systemPrompt
|
return systemPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HandlePartialResponse accepts a response and an err. If err is nil, it does
|
||||||
|
// nothing and returns nil. If response != "" and err != nil, it prints a
|
||||||
|
// warning and returns nil. If response == "" and err != nil, it returns the
|
||||||
|
// error.
|
||||||
|
func HandlePartialResponse(response string, err error) (e error) {
|
||||||
|
if err != nil {
|
||||||
|
if response != "" {
|
||||||
|
Warn("Received partial response. Error: %v\n", err)
|
||||||
|
} else {
|
||||||
|
e = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||||
// whatever input was provided there. placeholder is a string which populates
|
// whatever input was provided there. placeholder is a string which populates
|
||||||
@ -300,24 +315,21 @@ var replyCmd = &cobra.Command{
|
|||||||
assistantReply.RenderTTY()
|
assistantReply.RenderTTY()
|
||||||
|
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
response := make(chan string)
|
go HandleDelayedContent(receiver)
|
||||||
go func() {
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||||
response <- HandleDelayedResponse(receiver)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
err = HandlePartialResponse(response, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("Error while receiving response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assistantReply.OriginalContent = <-response
|
assistantReply.OriginalContent = response
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
err = store.SaveMessage(&assistantReply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("Could not save assistant reply: %v\n", err)
|
Fatal("Could not save assistant reply: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
@ -338,7 +350,6 @@ var newCmd = &cobra.Command{
|
|||||||
Fatal("No message was provided.\n")
|
Fatal("No message was provided.\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: set title if --title provided, otherwise defer for later(?)
|
|
||||||
conversation := Conversation{}
|
conversation := Conversation{}
|
||||||
err := store.SaveConversation(&conversation)
|
err := store.SaveConversation(&conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -372,25 +383,22 @@ var newCmd = &cobra.Command{
|
|||||||
reply.RenderTTY()
|
reply.RenderTTY()
|
||||||
|
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
response := make(chan string)
|
go HandleDelayedContent(receiver)
|
||||||
go func() {
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||||
response <- HandleDelayedResponse(receiver)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
err = HandlePartialResponse(response, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("Error while receiving response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
reply.OriginalContent = <-response
|
fmt.Println()
|
||||||
|
reply.OriginalContent = response
|
||||||
|
|
||||||
err = store.SaveMessage(&reply)
|
err = store.SaveMessage(&reply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("Could not save reply: %v\n", err)
|
Fatal("Could not save reply: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
err = conversation.GenerateTitle()
|
err = conversation.GenerateTitle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Warn("Could not generate title for conversation: %v\n", err)
|
Warn("Could not generate title for conversation: %v\n", err)
|
||||||
@ -425,8 +433,10 @@ var promptCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
go HandleDelayedResponse(receiver)
|
go HandleDelayedContent(receiver)
|
||||||
err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||||
|
|
||||||
|
err = HandlePartialResponse(response, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("%v\n", err)
|
||||||
}
|
}
|
||||||
@ -476,17 +486,16 @@ var retryCmd = &cobra.Command{
|
|||||||
assistantReply.RenderTTY()
|
assistantReply.RenderTTY()
|
||||||
|
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
response := make(chan string)
|
go HandleDelayedContent(receiver)
|
||||||
go func() {
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||||
response <- HandleStreamedResponse(receiver)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
err = HandlePartialResponse(response, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("Error while receiving response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assistantReply.OriginalContent = <-response
|
fmt.Println()
|
||||||
|
assistantReply.OriginalContent = response
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
err = store.SaveMessage(&assistantReply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -505,7 +514,6 @@ var retryCmd = &cobra.Command{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
var continueCmd = &cobra.Command{
|
var continueCmd = &cobra.Command{
|
||||||
Use: "continue <conversation>",
|
Use: "continue <conversation>",
|
||||||
Short: "Continues where the previous prompt left off.",
|
Short: "Continues where the previous prompt left off.",
|
||||||
@ -537,17 +545,16 @@ var continueCmd = &cobra.Command{
|
|||||||
assistantReply.RenderTTY()
|
assistantReply.RenderTTY()
|
||||||
|
|
||||||
receiver := make(chan string)
|
receiver := make(chan string)
|
||||||
response := make(chan string)
|
go HandleDelayedContent(receiver)
|
||||||
go func() {
|
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||||
response <- HandleStreamedResponse(receiver)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
err = HandlePartialResponse(response, err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("Error while receiving response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assistantReply.OriginalContent = <-response
|
fmt.Println()
|
||||||
|
assistantReply.OriginalContent = response
|
||||||
|
|
||||||
err = store.SaveMessage(&assistantReply)
|
err = store.SaveMessage(&assistantReply)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
@ -38,8 +39,9 @@ func CreateChatCompletion(model string, messages []Message, maxTokens int) (stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||||
// and streams the response to the provided output channel.
|
// and both returns and streams the response to the provided output channel.
|
||||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan string) error {
|
// May return a partial response if an error occurs mid-stream.
|
||||||
|
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
||||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||||
|
|
||||||
@ -47,20 +49,24 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
|
|||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer stream.Close()
|
defer stream.Close()
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
for {
|
for {
|
||||||
response, err := stream.Recv()
|
response, e := stream.Recv()
|
||||||
if errors.Is(err, io.EOF) {
|
if errors.Is(e, io.EOF) {
|
||||||
return nil
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if e != nil {
|
||||||
return err
|
err = e
|
||||||
|
break
|
||||||
}
|
}
|
||||||
output <- response.Choices[0].Delta.Content
|
chunk := response.Choices[0].Delta.Content
|
||||||
|
output <- chunk
|
||||||
|
sb.WriteString(chunk)
|
||||||
}
|
}
|
||||||
|
return sb.String(), err
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,6 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alecthomas/chroma/v2/quick"
|
"github.com/alecthomas/chroma/v2/quick"
|
||||||
@ -37,17 +36,17 @@ func ShowWaitAnimation(signal chan any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
|
// HandleDelayedContent displays a waiting animation to stdout while waiting
|
||||||
// (possibly chunked) content received on the response channel to stdout.
|
// for content to be received on the provided channel. As soon as any (possibly
|
||||||
|
// chunked) content is received on the channel, the waiting animation is
|
||||||
|
// replaced by the content.
|
||||||
// Blocks until the channel is closed.
|
// Blocks until the channel is closed.
|
||||||
func HandleDelayedResponse(response chan string) string {
|
func HandleDelayedContent(content <-chan string) {
|
||||||
waitSignal := make(chan any)
|
waitSignal := make(chan any)
|
||||||
go ShowWaitAnimation(waitSignal)
|
go ShowWaitAnimation(waitSignal)
|
||||||
|
|
||||||
sb := strings.Builder{}
|
|
||||||
|
|
||||||
firstChunk := true
|
firstChunk := true
|
||||||
for chunk := range response {
|
for chunk := range content {
|
||||||
if firstChunk {
|
if firstChunk {
|
||||||
// notify wait animation that we've received data
|
// notify wait animation that we've received data
|
||||||
waitSignal <- ""
|
waitSignal <- ""
|
||||||
@ -56,10 +55,7 @@ func HandleDelayedResponse(response chan string) string {
|
|||||||
firstChunk = false
|
firstChunk = false
|
||||||
}
|
}
|
||||||
fmt.Print(chunk)
|
fmt.Print(chunk)
|
||||||
sb.WriteString(chunk)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenderConversation renders the given messages to TTY, with optional space
|
// RenderConversation renders the given messages to TTY, with optional space
|
||||||
|
Loading…
Reference in New Issue
Block a user