Refactor the last refactor :)

Removed HandlePartialResponse, add LLMRequest which handles all common
logic of making LLM requests and returning/showing their response.
This commit is contained in:
Matt Low 2023-11-24 15:17:24 +00:00
parent 6249fbc8f8
commit c02b21ca37
2 changed files with 31 additions and 56 deletions

View File

@ -53,19 +53,28 @@ func SystemPrompt() string {
return systemPrompt return systemPrompt
} }
// HandlePartialResponse accepts a response and an err. If err is nil, it does // LLMRequest prompts the LLM with the given Message, writes the (partial)
// nothing and returns nil. If response != "" and err != nil, it prints a // response to stdout, and returns the (partial) response or any errors.
// warning and returns nil. If response == "" and err != nil, it returns the func LLMRequest(messages []Message) (string, error) {
// error. // receiver receives the reponse from LLM
func HandlePartialResponse(response string, err error) (e error) { receiver := make(chan string)
if err != nil { defer close(receiver)
if response != "" {
Warn("Received partial response. Error: %v\n", err) // start HandleDelayedContent goroutine to print received data to stdout
} else { go HandleDelayedContent(receiver)
e = err
} response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil && response != "" {
Warn("Received partial response. Error: %v\n", err)
err = nil // ignore partial response error
} }
return
if response != "" {
// there was some content, so break to a new line after it
fmt.Println()
}
return response, err
} }
// InputFromArgsOrEditor returns either the provided input from the args slice // InputFromArgsOrEditor returns either the provided input from the args slice
@ -314,17 +323,12 @@ var replyCmd = &cobra.Command{
} }
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) response, err := LLMRequest(messages)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("Error while receiving response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
assistantReply.OriginalContent = response assistantReply.OriginalContent = response
fmt.Println()
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {
@ -382,16 +386,11 @@ var newCmd = &cobra.Command{
} }
reply.RenderTTY() reply.RenderTTY()
receiver := make(chan string) response, err := LLMRequest(messages)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("Error while receiving response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
fmt.Println()
reply.OriginalContent = response reply.OriginalContent = response
err = store.SaveMessage(&reply) err = store.SaveMessage(&reply)
@ -432,16 +431,10 @@ var promptCmd = &cobra.Command{
}, },
} }
receiver := make(chan string) _, err := LLMRequest(messages)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
fmt.Println()
}, },
} }
@ -485,25 +478,17 @@ var retryCmd = &cobra.Command{
} }
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) response, err := LLMRequest(messages)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("Error while receiving response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
fmt.Println()
assistantReply.OriginalContent = response assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Fatal("Could not save assistant reply: %v\n", err)
} }
fmt.Println()
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -544,25 +529,17 @@ var continueCmd = &cobra.Command{
} }
assistantReply.RenderTTY() assistantReply.RenderTTY()
receiver := make(chan string) response, err := LLMRequest(messages)
go HandleDelayedContent(receiver)
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
err = HandlePartialResponse(response, err)
if err != nil { if err != nil {
Fatal("Error while receiving response: %v\n", err) Fatal("Error fetching LLM response: %v\n", err)
} }
fmt.Println()
assistantReply.OriginalContent = response assistantReply.OriginalContent = response
err = store.SaveMessage(&assistantReply) err = store.SaveMessage(&assistantReply)
if err != nil { if err != nil {
Fatal("Could not save assistant reply: %v\n", err) Fatal("Could not save assistant reply: %v\n", err)
} }
fmt.Println()
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp

View File

@ -45,8 +45,6 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
client := openai.NewClient(*config.OpenAI.APIKey) client := openai.NewClient(*config.OpenAI.APIKey)
req := CreateChatCompletionRequest(model, messages, maxTokens) req := CreateChatCompletionRequest(model, messages, maxTokens)
defer close(output)
stream, err := client.CreateChatCompletionStream(context.Background(), req) stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil { if err != nil {
return "", err return "", err