Refactor the last refactor :)
Remove HandlePartialResponse and add LLMRequest, which handles all the common logic of making LLM requests and returning/showing their responses.
parent 6249fbc8f8
commit c02b21ca37
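As an aside, here is a minimal, self-contained Go sketch of the streaming pattern that LLMRequest centralizes. fakeStream and llmRequest below are hypothetical stand-ins for CreateChatCompletionStream and the real LLMRequest, not code from this repository.

package main

import (
	"errors"
	"fmt"
)

// fakeStream is a hypothetical stand-in for CreateChatCompletionStream: it
// pushes chunks to the receiver channel and returns the accumulated text
// together with any error from the stream.
func fakeStream(receiver chan<- string) (string, error) {
	var full string
	for _, chunk := range []string{"Hello", ", ", "world"} {
		receiver <- chunk
		full += chunk
	}
	// Simulate the stream breaking off after some content has arrived.
	return full, errors.New("stream interrupted")
}

// llmRequest mirrors the pattern this commit centralizes: print streamed
// chunks as they arrive, downgrade a partial-response error to a warning,
// and return the (partial) response.
func llmRequest() (string, error) {
	receiver := make(chan string)
	done := make(chan struct{})

	// Print received chunks to stdout as they come in (the role
	// HandleDelayedContent plays in the real code).
	go func() {
		for chunk := range receiver {
			fmt.Print(chunk)
		}
		close(done)
	}()

	response, err := fakeStream(receiver)
	close(receiver)
	<-done

	if err != nil && response != "" {
		fmt.Printf("\nWarning: received partial response. Error: %v\n", err)
		err = nil // ignore the error; keep the partial content
	}
	if response != "" {
		fmt.Println() // break to a new line after the streamed content
	}
	return response, err
}

func main() {
	response, err := llmRequest()
	if err != nil {
		fmt.Printf("Error fetching LLM response: %v\n", err)
		return
	}
	fmt.Printf("Got %d bytes of response\n", len(response))
}

Callers are then reduced to a single call plus error check, which is the simplification the command hunks below apply to replyCmd, newCmd, promptCmd, retryCmd, and continueCmd.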
@@ -53,19 +53,28 @@ func SystemPrompt() string {
 	return systemPrompt
 }
 
-// HandlePartialResponse accepts a response and an err. If err is nil, it does
-// nothing and returns nil. If response != "" and err != nil, it prints a
-// warning and returns nil. If response == "" and err != nil, it returns the
-// error.
-func HandlePartialResponse(response string, err error) (e error) {
-	if err != nil {
-		if response != "" {
-			Warn("Received partial response. Error: %v\n", err)
-		} else {
-			e = err
-		}
-	}
-	return
-}
+// LLMRequest prompts the LLM with the given Messages, writes the (partial)
+// response to stdout, and returns the (partial) response or any errors.
+func LLMRequest(messages []Message) (string, error) {
+	// receiver receives the response from the LLM
+	receiver := make(chan string)
+	defer close(receiver)
+
+	// start a HandleDelayedContent goroutine to print received data to stdout
+	go HandleDelayedContent(receiver)
+
+	response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
+	if err != nil && response != "" {
+		Warn("Received partial response. Error: %v\n", err)
+		err = nil // ignore partial response error
+	}
+
+	if response != "" {
+		// there was some content, so break to a new line after it
+		fmt.Println()
+	}
+
+	return response, err
+}
 
 // InputFromArgsOrEditor returns either the provided input from the args slice
@@ -314,17 +323,12 @@ var replyCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
 		assistantReply.OriginalContent = response
-		fmt.Println()
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
@@ -382,16 +386,11 @@ var newCmd = &cobra.Command{
 		}
 		reply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		reply.OriginalContent = response
 
 		err = store.SaveMessage(&reply)
@@ -432,16 +431,10 @@ var promptCmd = &cobra.Command{
 			},
 		}
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		_, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("%v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
-
-		fmt.Println()
 	},
 }
 
@@ -485,25 +478,17 @@ var retryCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -544,25 +529,17 @@ var continueCmd = &cobra.Command{
 		}
 		assistantReply.RenderTTY()
 
-		receiver := make(chan string)
-		go HandleDelayedContent(receiver)
-		response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
-
-		err = HandlePartialResponse(response, err)
+		response, err := LLMRequest(messages)
 		if err != nil {
-			Fatal("Error while receiving response: %v\n", err)
+			Fatal("Error fetching LLM response: %v\n", err)
 		}
 
-		fmt.Println()
 		assistantReply.OriginalContent = response
 
 		err = store.SaveMessage(&assistantReply)
 		if err != nil {
 			Fatal("Could not save assistant reply: %v\n", err)
 		}
-
-		fmt.Println()
-
 	},
 	ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
 		compMode := cobra.ShellCompDirectiveNoFileComp
@@ -45,8 +45,6 @@ func CreateChatCompletionStream(model string, messages []Message, maxTokens int,
 	client := openai.NewClient(*config.OpenAI.APIKey)
 	req := CreateChatCompletionRequest(model, messages, maxTokens)
 
-	defer close(output)
-
 	stream, err := client.CreateChatCompletionStream(context.Background(), req)
 	if err != nil {
 		return "", err