Refactor the last refactor :)
Remove HandlePartialResponse and add LLMRequest, which handles all the common logic of making LLM requests and returning/showing their responses.
parent 6249fbc8f8
commit c02b21ca37
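
Before the diff, a minimal, self-contained Go sketch of the streaming pattern the new LLMRequest centralizes: a goroutine drains a channel to stdout while the completion call fills it, a non-empty partial response downgrades the error to a warning, and a trailing newline is printed after any output. fakeStream and request below are hypothetical stand-ins for CreateChatCompletionStream and LLMRequest, and the done channel is added here only to make the example deterministic; the commit itself relies on defer close(receiver).

```go
package main

import (
	"errors"
	"fmt"
)

// fakeStream stands in for CreateChatCompletionStream (hypothetical): it
// pushes chunks to out and returns the accumulated text plus an error that
// simulates a stream cut short after producing partial output.
func fakeStream(out chan<- string) (string, error) {
	var full string
	for _, chunk := range []string{"Hello", ", ", "world"} {
		out <- chunk
		full += chunk
	}
	return full, errors.New("stream cut short")
}

// request mirrors the shape of the new LLMRequest: start a printer goroutine
// (playing the role of HandleDelayedContent), run the stream, keep a partial
// response while downgrading its error to a warning, and break the line
// after any output.
func request() (string, error) {
	receiver := make(chan string)
	done := make(chan struct{})

	go func() {
		for chunk := range receiver {
			fmt.Print(chunk)
		}
		close(done)
	}()

	response, err := fakeStream(receiver)
	close(receiver)
	<-done // wait for the printer to finish before writing anything else

	if err != nil && response != "" {
		fmt.Printf("\nwarning: received partial response: %v\n", err)
		err = nil // keep the partial content, drop the error
	}
	if response != "" {
		fmt.Println()
	}
	return response, err
}

func main() {
	if _, err := request(); err != nil {
		fmt.Println("error:", err)
	}
}
```
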
@@ -53,19 +53,28 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
-// HandlePartialResponse accepts a response and an err. If err is nil, it does
-// nothing and returns nil. If response != "" and err != nil, it prints a
-// warning and returns nil. If response == "" and err != nil, it returns the
-// error.
-func HandlePartialResponse(response string, err error) (e error) {
-	if err != nil {
-		if response != "" {
-			Warn("Received partial response. Error: %v\n", err)
-		} else {
-			e = err
-		}
+// LLMRequest prompts the LLM with the given Message, writes the (partial)
+// response to stdout, and returns the (partial) response or any errors.
+func LLMRequest(messages []Message) (string, error) {
+	// receiver receives the reponse from LLM
+	receiver := make(chan string)
+	defer close(receiver)
+
+	// start HandleDelayedContent goroutine to print received data to stdout
+	go HandleDelayedContent(receiver)
+
+	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+	if err != nil && response != "" {
+		Warn("Received partial response. Error: %v\n", err)
+		err = nil // ignore partial response error
 	}
-	return
+
+	if response != "" {
+		// there was some content, so break to a new line after it
+		fmt.Println()
+	}
+
+	return response, err
 }
 
 // InputFromArgsOrEditor returns either the provided input from the args slice
@@ -314,17 +323,12 @@ var replyCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
 		assistantReply.OriginalContent = response
-		fmt.Println()
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
@@ -382,16 +386,11 @@ var newCmd = &cobra.Command{
 		}
 		reply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		reply.OriginalContent = response
 
 		err = store.SaveMessage(&reply)
@@ -432,16 +431,10 @@ var promptCmd = &cobra.Command{
 			},
 		}
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		_, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
-
-		fmt.Println()
 	},
 }
 
@@ -485,25 +478,17 @@ var retryCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -544,25 +529,17 @@ var continueCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -45,8 +45,6 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 
-	defer close(output)
-
 	stream, err := client.CreateChatCompletionStream(context.Background(), req)
 	if err != nil {
 		return "", err
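
A note on the last hunk: CreateChatCompletionStream previously closed the output channel it was handed; after this change the channel is closed by LLMRequest, the function that creates it. A small sketch of that ownership convention, with a hypothetical produce function standing in for the streaming call and the producer/consumer roles simplified:

```go
package main

import "fmt"

// produce is a hypothetical stand-in for CreateChatCompletionStream after the
// change: it only sends on the channel and never closes it.
func produce(out chan<- string) {
	for _, s := range []string{"a", "b", "c"} {
		out <- s
	}
}

func main() {
	ch := make(chan string)

	// The creator of the channel owns closing it, mirroring
	// LLMRequest's defer close(receiver) in this commit.
	go func() {
		defer close(ch)
		produce(ch)
	}()

	for s := range ch {
		fmt.Println(s)
	}
}
```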