From c02b21ca3705dc7cb5d785a5eb88eefd3b268355 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Fri, 24 Nov 2023 15:17:24 +0000
Subject: [PATCH] Refactor the last refactor :)

Remove HandlePartialResponse; add LLMRequest, which handles all the
common logic of making LLM requests and returning/showing their
responses.
---
 pkg/cli/cmd.go    | 85 +++++++++++++++++------------------------------
 pkg/cli/openai.go |  2 --
 2 files changed, 31 insertions(+), 56 deletions(-)

diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go
index 2494cfe..befb794 100644
--- a/pkg/cli/cmd.go
+++ b/pkg/cli/cmd.go
@@ -53,19 +53,28 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
-// HandlePartialResponse accepts a response and an err. If err is nil, it does
-// nothing and returns nil. If response != "" and err != nil, it prints a
-// warning and returns nil. If response == "" and err != nil, it returns the
-// error.
-func HandlePartialResponse(response string, err error) (e error) {
-	if err != nil {
-		if response != "" {
-			Warn("Received partial response. Error: %v\n", err)
-		} else {
-			e = err
-		}
+// LLMRequest prompts the LLM with the given messages, writes the (partial)
+// response to stdout, and returns the (partial) response or any error.
+func LLMRequest(messages []Message) (string, error) {
+	// receiver receives the response from the LLM
+	receiver := make(chan string)
+	defer close(receiver)
+
+	// start the HandleDelayedContent goroutine to print received data to stdout
+	go HandleDelayedContent(receiver)
+
+	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+	if err != nil && response != "" {
+		Warn("Received partial response. Error: %v\n", err)
+		err = nil // ignore the error; we still have a usable partial response
 	}
-	return
+
+	if response != "" {
+		// there was some content, so break to a new line after it
+		fmt.Println()
+	}
+
+	return response, err
 }
 
 // InputFromArgsOrEditor returns either the provided input from the args slice
@@ -314,17 +323,12 @@ var replyCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
 		assistantReply.OriginalContent = response
-		fmt.Println()
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
@@ -382,16 +386,11 @@ var newCmd = &cobra.Command{
 		}
 		reply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		reply.OriginalContent = response
 
 		err = store.SaveMessage(&reply)
@@ -432,16 +431,10 @@ var promptCmd = &cobra.Command{
 			},
 		}
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		_, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
-
-		fmt.Println()
 	},
 }
 
@@ -485,25 +478,17 @@ var retryCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
-		fmt.Println()
 
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -544,25 +529,17 @@ var continueCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
-		fmt.Println()
 
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
diff --git a/pkg/cli/openai.go b/pkg/cli/openai.go
index 3ef2100..ae7bad8 100644
--- a/pkg/cli/openai.go
+++ b/pkg/cli/openai.go
@@ -45,8 +45,6 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 
-	defer close(output)
-
 	stream, err := client.CreateChatCompletionStream(context.Background(), req)
 	if err != nil {
 		return "", err
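
A note for reviewers: the contract this patch settles on is that
CreateChatCompletionStream no longer closes its output channel (hence the
removed "defer close(output)"); LLMRequest, which creates the receiver
channel, now owns closing it. Below is a minimal, self-contained sketch of
that ownership pattern, including the partial-response tolerance. This is an
illustration, not code from the patch: fakeCompletionStream is a hypothetical
stand-in for CreateChatCompletionStream, and the printer goroutine with its
done-channel synchronization only approximates HandleDelayedContent, whose
implementation is not shown here.

package main

import (
	"errors"
	"fmt"
	"strings"
)

// fakeCompletionStream stands in for CreateChatCompletionStream: it writes
// chunks to output but never closes it; the caller owns the channel.
func fakeCompletionStream(output chan<- string) (string, error) {
	var sb strings.Builder
	for _, chunk := range []string{"Hello", ", ", "world"} {
		output <- chunk
		sb.WriteString(chunk)
	}
	// simulate the stream dying after a partial response was received
	return sb.String(), errors.New("stream interrupted")
}

// request mirrors the shape of LLMRequest: start a printer goroutine, run
// the stream, and tolerate an error when a partial response came through.
func request() (string, error) {
	receiver := make(chan string)
	done := make(chan struct{})

	// printer goroutine, in the role of HandleDelayedContent
	go func() {
		defer close(done)
		for chunk := range receiver {
			fmt.Print(chunk)
		}
	}()

	response, err := fakeCompletionStream(receiver)
	close(receiver) // the caller closes the channel (LLMRequest does this via defer)
	<-done          // wait for the printer to drain before printing anything else

	if err != nil && response != "" {
		fmt.Printf("\nWarn: received partial response. Error: %v\n", err)
		err = nil // the partial response is still usable
	}
	if response != "" {
		fmt.Println() // break to a new line after the streamed content
	}
	return response, err
}

func main() {
	response, err := request()
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("saved reply: %q\n", response)
}

Running this streams the chunks, warns about the interrupted stream, and
still returns the partial response; that is the flow replyCmd, newCmd,
promptCmd, retryCmd, and continueCmd now share through the single
LLMRequest helper.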