Compare commits

..

No commits in common. "91d3c9c2e126d39225c0e14129b2b7d1da93ae36" and "045146bb5c57601c787fd67c32bdf503d114cbff" have entirely different histories.

6 changed files with 54 additions and 77 deletions

View File

@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
fmt.Print(lastMessage.Content) fmt.Print(lastMessage.Content)
// Submit the LLM request, allowing it to continue the last message // Submit the LLM request, allowing it to continue the last message
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil) continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("error fetching LLM response: %v", err) return fmt.Errorf("error fetching LLM response: %v", err)
} }
// Append the new response to the original message // Append the new response to the original message
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ") lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
// Update the original message // Update the original message
err = ctx.Store.UpdateMessage(lastMessage) err = ctx.Store.UpdateMessage(lastMessage)

View File

@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
}, },
} }
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil) _, err := cmdutil.FetchAndShowCompletion(ctx, messages)
if err != nil { if err != nil {
return fmt.Errorf("Error fetching LLM response: %v", err) return fmt.Errorf("Error fetching LLM response: %v", err)
} }

View File

@ -1,7 +1,6 @@
package util package util
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@ -15,7 +14,7 @@ import (
// fetchAndShowCompletion prompts the LLM with the given messages and streams // fetchAndShowCompletion prompts the LLM with the given messages and streams
// the response to stdout. Returns all model reply messages. // the response to stdout. Returns all model reply messages.
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) { func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
content := make(chan string) // receives the reponse from LLM content := make(chan string) // receives the reponse from LLM
defer close(content) defer close(content)
@ -24,7 +23,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model) completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
if err != nil { if err != nil {
return "", err return nil, err
} }
requestParams := model.RequestParameters{ requestParams := model.RequestParameters{
@ -34,8 +33,9 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
ToolBag: ctx.EnabledTools, ToolBag: ctx.EnabledTools,
} }
var apiReplies []model.Message
response, err := completionProvider.CreateChatCompletionStream( response, err := completionProvider.CreateChatCompletionStream(
context.Background(), requestParams, messages, callback, content, requestParams, messages, &apiReplies, content,
) )
if response != "" { if response != "" {
// there was some content, so break to a new line after it // there was some content, so break to a new line after it
@ -46,7 +46,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
err = nil err = nil
} }
} }
return response, nil
return apiReplies, err
} }
// lookupConversation either returns the conversation found by the // lookupConversation either returns the conversation found by the
@ -97,21 +98,20 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
// render a message header with no contents // render a message header with no contents
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant})) RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
replyCallback := func(reply model.Message) { replies, err := FetchAndShowCompletion(ctx, allMessages)
if !persist { if err != nil {
return lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
if persist {
for _, reply := range replies {
reply.ConversationID = c.ID reply.ConversationID = c.ID
err = ctx.Store.SaveMessage(&reply) err = ctx.Store.SaveMessage(&reply)
if err != nil { if err != nil {
lmcli.Warn("Could not save reply: %v\n", err) lmcli.Warn("Could not save reply: %v\n", err)
} }
} }
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
if err != nil {
lmcli.Fatal("Error fetching LLM response: %v\n", err)
} }
} }
@ -153,7 +153,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
MaxTokens: 25, MaxTokens: 25,
} }
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil) response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -3,15 +3,14 @@ package anthropic
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
) )
type AnthropicClient struct { type AnthropicClient struct {
@ -103,7 +102,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
return requestBody return requestBody
} }
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) { func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
url := "https://api.anthropic.com/v1/messages" url := "https://api.anthropic.com/v1/messages"
jsonBody, err := json.Marshal(r) jsonBody, err := json.Marshal(r)
@ -111,7 +110,7 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
return nil, fmt.Errorf("failed to marshal request body: %v", err) return nil, fmt.Errorf("failed to marshal request body: %v", err)
} }
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody)) req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create HTTP request: %v", err) return nil, fmt.Errorf("failed to create HTTP request: %v", err)
} }
@ -130,14 +129,13 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
} }
func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(c, request)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -162,25 +160,22 @@ func (c *AnthropicClient) CreateChatCompletion(
default: default:
return "", fmt.Errorf("unsupported message type: %s", content.Type) return "", fmt.Errorf("unsupported message type: %s", content.Type)
} }
if callback != nil { *replies = append(*replies, reply)
callback(reply)
}
} }
return sb.String(), nil return sb.String(), nil
} }
func (c *AnthropicClient) CreateChatCompletionStream( func (c *AnthropicClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
request := buildRequest(params, messages) request := buildRequest(params, messages)
request.Stream = true request.Stream = true
resp, err := sendRequest(ctx, c, request) resp, err := sendRequest(c, request)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -293,25 +288,23 @@ func (c *AnthropicClient) CreateChatCompletionStream(
ToolResults: toolResults, ToolResults: toolResults,
} }
if callback != nil { if replies != nil {
callback(toolCall) *replies = append(append(*replies, toolCall), toolReply)
callback(toolReply)
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
// added to the original messages // added to the original messages
messages = append(append(messages, toolCall), toolReply) messages = append(append(messages, toolCall), toolReply)
return c.CreateChatCompletionStream(ctx, params, messages, callback, output) return c.CreateChatCompletionStream(params, messages, replies, output)
} }
} }
case "message_stop": case "message_stop":
// return the completed message // return the completed message
if callback != nil { reply := model.Message{
callback(model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: sb.String(), Content: sb.String(),
})
} }
*replies = append(*replies, reply)
return sb.String(), nil return sb.String(), nil
case "error": case "error":
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"]) return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

View File

@ -9,7 +9,6 @@ import (
"strings" "strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model" "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools" "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
@ -158,14 +157,13 @@ func handleToolCalls(
} }
func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback provider.ReplyCallback, replies *[]model.Message,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
resp, err := client.CreateChatCompletion(ctx, req) resp, err := client.CreateChatCompletion(context.Background(), req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -178,19 +176,17 @@ func (c *OpenAIClient) CreateChatCompletion(
if err != nil { if err != nil {
return "", err return "", err
} }
if callback != nil { if results != nil {
for _, result := range results { *replies = append(*replies, results...)
callback(result)
}
} }
// Recurse into CreateChatCompletion with the tool call replies // Recurse into CreateChatCompletion with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletion(ctx, params, messages, callback) return c.CreateChatCompletion(params, messages, replies)
} }
if callback != nil { if replies != nil {
callback(model.Message{ *replies = append(*replies, model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: choice.Message.Content, Content: choice.Message.Content,
}) })
@ -201,16 +197,15 @@ func (c *OpenAIClient) CreateChatCompletion(
} }
func (c *OpenAIClient) CreateChatCompletionStream( func (c *OpenAIClient) CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callbback provider.ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) { ) (string, error) {
client := openai.NewClient(c.APIKey) client := openai.NewClient(c.APIKey)
req := createChatCompletionRequest(c, params, messages) req := createChatCompletionRequest(c, params, messages)
stream, err := client.CreateChatCompletionStream(ctx, req) stream, err := client.CreateChatCompletionStream(context.Background(), req)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -255,20 +250,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
if err != nil { if err != nil {
return content.String(), err return content.String(), err
} }
if results != nil {
if callbback != nil { *replies = append(*replies, results...)
for _, result := range results {
callbback(result)
}
} }
// Recurse into CreateChatCompletionStream with the tool call replies // Recurse into CreateChatCompletionStream with the tool call replies
messages = append(messages, results...) messages = append(messages, results...)
return c.CreateChatCompletionStream(ctx, params, messages, callbback, output) return c.CreateChatCompletionStream(params, messages, replies, output)
} }
if callbback != nil { if replies != nil {
callbback(model.Message{ *replies = append(*replies, model.Message{
Role: model.MessageRoleAssistant, Role: model.MessageRoleAssistant,
Content: content.String(), Content: content.String(),
}) })

View File

@ -1,31 +1,23 @@
package provider package provider
import ( import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
"context"
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)
type ReplyCallback func(model.Message)
type ChatCompletionClient interface { type ChatCompletionClient interface {
// CreateChatCompletion requests a response to the provided messages. // CreateChatCompletion requests a response to the provided messages.
// Replies are appended to the given replies struct, and the // Replies are appended to the given replies struct, and the
// complete user-facing response is returned as a string. // complete user-facing response is returned as a string.
CreateChatCompletion( CreateChatCompletion(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, replies *[]model.Message,
) (string, error) ) (string, error)
// Like CreateChageCompletion, except the response is streamed via // Like CreateChageCompletion, except the response is streamed via
// the output channel as it's received. // the output channel as it's received.
CreateChatCompletionStream( CreateChatCompletionStream(
ctx context.Context,
params model.RequestParameters, params model.RequestParameters,
messages []model.Message, messages []model.Message,
callback ReplyCallback, replies *[]model.Message,
output chan<- string, output chan<- string,
) (string, error) ) (string, error)
} }