From 68f986dc06a803a7a2ce6f5aef3bc1a03784cb3f Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 30 Oct 2023 21:45:21 +0000 Subject: [PATCH] Use the streamed response API --- main.go | 6 +++--- openai.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/main.go b/main.go index 8429223..af57fb3 100644 --- a/main.go +++ b/main.go @@ -107,13 +107,13 @@ var newCmd = &cobra.Command{ }, } - response, err := CreateChatCompletion("You are a helpful assistant.", messages) + err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout) if err != nil { - fmt.Fprintf(os.Stderr, "Error getting chat response: %v\n", err) + fmt.Fprintf(os.Stderr, "An error occured: %v\n", err) os.Exit(1) } - fmt.Println(response); + fmt.Println() }, } diff --git a/openai.go b/openai.go index 9b76372..476b174 100644 --- a/openai.go +++ b/openai.go @@ -2,7 +2,9 @@ package main import ( "context" + "errors" "fmt" + "io" "os" openai "github.com/sashabaranov/go-openai" ) @@ -34,3 +36,45 @@ func CreateChatCompletion(system string, messages []Message) (string, error) { return resp.Choices[0].Message.Content, nil } + +func CreateChatCompletionStream(system string, messages []Message, output io.Writer) (error) { + client := openai.NewClient(os.Getenv("OPENAI_APIKEY")) + ctx := context.Background() + + + var chatCompletionMessages []openai.ChatCompletionMessage + for _, m := range(messages) { + chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ + Role: m.Role, + Content: m.OriginalContent, + }) + } + + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + MaxTokens: 20, + Messages: chatCompletionMessages, + Stream: true, + } + + stream, err := client.CreateChatCompletionStream(ctx, req) + if err != nil { + return err + } + + defer stream.Close() + + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + return nil + } + + if err != nil { + //fmt.Printf("\nStream error: %v\n", err) + return err + } + + fmt.Fprint(output, response.Choices[0].Delta.Content) + } +}