Refactor the last refactor :)

Removed HandlePartialResponse, add LLMRequest which handles all common
logic of making LLM requests and returning/showing their response.
This commit is contained in:
Matt Low 2023-11-24 15:17:24 +00:00
parent 6249fbc8f8
commit c02b21ca37
2 changed files with 31 additions and 56 deletions

View File

@ -53,19 +53,28 @@ func SystemPrompt() string {
return systemPrompt
}
// HandlePartialResponse accepts a response and an err. If err is nil, it does
// nothing and returns nil. If response != "" and err != nil, it prints a
// warning and returns nil. If response == "" and err != nil, it returns the
// error.
func HandlePartialResponse(response string, err error) (e error) {
if err != nil {
if response != "" {
// LLMRequest prompts the LLM with the given Message, writes the (partial)
// response to stdout, and returns the (partial) response or any errors.
func LLMRequest(messages []Message) (string, error) {
// receiver receives the reponse from LLM
receiver := make(chan string)
defer close(receiver)
// start HandleDelayedContent goroutine to print received data to stdout
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil && response != "" {
Warn("Received partial response. Error: %v\n", err)
} else {
e = err
err = nil // ignore partial response error
}
if response != "" {
// there was some content, so break to a new line after it
fmt.Println()
}
return
return response, err
}
// InputFromArgsOrEditor returns either the provided input from the args slice
@ -314,17 +323,12 @@ var replyCmd = &cobra.Command{
}
assistantReply.RenderTTY()
receiver := make(chan string)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error while receiving response: %v\n", err)
Fatal("Error fetching LLM response: %v\n", err)
}
assistantReply.OriginalContent = response
fmt.Println()
err = store.SaveMessage(&assistantReply)
if err != nil {
@ -382,16 +386,11 @@ var newCmd = &cobra.Command{
}
reply.RenderTTY()
receiver := make(chan string)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error while receiving response: %v\n", err)
Fatal("Error fetching LLM response: %v\n", err)
}
fmt.Println()
reply.OriginalContent = response
err = store.SaveMessage(&reply)
@ -432,16 +431,10 @@ var promptCmd = &cobra.Command{
},
}
receiver := make(chan string)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
_, err := LLMRequest(messages)
if err != nil {
Fatal("%v\n", err)
Fatal("Error fetching LLM response: %v\n", err)
}
fmt.Println()
},
}
@ -485,25 +478,17 @@ var retryCmd = &cobra.Command{
}
assistantReply.RenderTTY()
receiver := make(chan string)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error while receiving response: %v\n", err)
Fatal("Error fetching LLM response: %v\n", err)
}
fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
fmt.Println()
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp
@ -544,25 +529,17 @@ var continueCmd = &cobra.Command{
}
assistantReply.RenderTTY()
receiver := make(chan string)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
response, err := LLMRequest(messages)
if err != nil {
Fatal("Error while receiving response: %v\n", err)
Fatal("Error fetching LLM response: %v\n", err)
}
fmt.Println()
assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply)
if err != nil {
Fatal("Could not save assistant reply: %v\n", err)
}
fmt.Println()
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp

View File

@ -45,8 +45,6 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens)
defer close(output)
stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil {
return "", err