Compare commits


No commits in common. "91d3c9c2e126d39225c0e14129b2b7d1da93ae36" and "045146bb5c57601c787fd67c32bdf503d114cbff" have entirely different histories.

6 changed files with 54 additions and 77 deletions

View File

@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)

 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}

 			// Append the new response to the original message
-			lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
+			lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")

 			// Update the original message
 			err = ctx.Store.UpdateMessage(lastMessage)
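Reviewer note: `FetchAndShowCompletion` now returns the reply messages themselves, so `continue` reads `continuedOutput[0].Content`. Since a tool-call round can produce more than one reply, a caller that wants all of the assistant text could concatenate instead. A minimal sketch of that alternative (illustrative only, not part of this diff):

// Sketch: gather assistant text from every returned reply rather than
// assuming the first message holds it all. Uses only names from this diff.
var sb strings.Builder
for _, reply := range continuedOutput {
	if reply.Role == model.MessageRoleAssistant {
		sb.WriteString(reply.Content)
	}
}
lastMessage.Content += strings.TrimRight(sb.String(), "\n\t ")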

View File

@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}

-	_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+	_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
 	if err != nil {
 		return fmt.Errorf("Error fetching LLM response: %v", err)
 	}

View File

@@ -1,7 +1,6 @@
 package util

 import (
-	"context"
 	"fmt"
 	"os"
 	"strings"
@@ -15,7 +14,7 @@ import (
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
@@ -24,7 +23,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return "", err
+		return nil, err
 	}

 	requestParams := model.RequestParameters{
@@ -34,8 +33,9 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
 		ToolBag: ctx.EnabledTools,
 	}

+	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		context.Background(), requestParams, messages, callback, content,
+		requestParams, messages, &apiReplies, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it
@@ -46,7 +46,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
 			err = nil
 		}
 	}
-	return response, nil
+
+	return apiReplies, err
 }

 // lookupConversation either returns the conversation found by the
@@ -97,21 +98,20 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))

-	replyCallback := func(reply model.Message) {
-		if !persist {
-			return
+	replies, err := FetchAndShowCompletion(ctx, allMessages)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}

+	if persist {
+		for _, reply := range replies {
 			reply.ConversationID = c.ID
 			err = ctx.Store.SaveMessage(&reply)
 			if err != nil {
 				lmcli.Warn("Could not save reply: %v\n", err)
 			}
 		}
-
-	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
-	}
 	}
@@ -153,7 +153,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
 		MaxTokens: 25,
 	}

-	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
+	response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
 	if err != nil {
 		return "", err
 	}
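Reviewer note: `GenerateTitle` passes `nil` for the new replies argument. Since the Anthropic implementation below appends through the pointer without a nil check, a defensive call site could pass a throwaway slice instead. A sketch of that variant (hypothetical, not in this diff):

// Sketch: avoid handing providers a nil replies pointer.
var discard []model.Message
response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, &discard)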

View File

@@ -3,15 +3,14 @@ package anthropic
 import (
 	"bufio"
 	"bytes"
-	"context"
 	"encoding/json"
 	"encoding/xml"
 	"fmt"
 	"net/http"
 	"strings"

-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
 )
@@ -103,7 +102,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 	return requestBody
 }

-func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
+func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
 	url := "https://api.anthropic.com/v1/messages"

 	jsonBody, err := json.Marshal(r)
@@ -111,7 +110,7 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
 		return nil, fmt.Errorf("failed to marshal request body: %v", err)
 	}

-	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
+	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
 	}
@@ -130,14 +129,13 @@ func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Resp
 }

 func (c *AnthropicClient) CreateChatCompletion(
-	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	replies *[]model.Message,
 ) (string, error) {
 	request := buildRequest(params, messages)

-	resp, err := sendRequest(ctx, c, request)
+	resp, err := sendRequest(c, request)
 	if err != nil {
 		return "", err
 	}
@@ -162,25 +160,22 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}

-		if callback != nil {
-			callback(reply)
-		}
+		*replies = append(*replies, reply)
 	}

 	return sb.String(), nil
 }

 func (c *AnthropicClient) CreateChatCompletionStream(
-	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	replies *[]model.Message,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
 	request.Stream = true

-	resp, err := sendRequest(ctx, c, request)
+	resp, err := sendRequest(c, request)
 	if err != nil {
 		return "", err
 	}
@@ -293,25 +288,23 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}

-				if callback != nil {
-					callback(toolCall)
-					callback(toolReply)
+				if replies != nil {
+					*replies = append(append(*replies, toolCall), toolReply)
 				}

 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
+				return c.CreateChatCompletionStream(params, messages, replies, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
-		if callback != nil {
-			callback(model.Message{
+		reply := model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: sb.String(),
-			})
 		}
+		*replies = append(*replies, reply)
 		return sb.String(), nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
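Possible follow-up here: the non-streaming reply loop and the `message_stop` branch both append unconditionally, while the tool-call branch guards with `if replies != nil`. Guarding the final append the same way would keep nil callers (such as `GenerateTitle`) safe. Sketch:

// Sketch: mirror the nil guard used in the tool-call branch above.
if replies != nil {
	*replies = append(*replies, reply)
}
return sb.String(), nil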

View File

@@ -9,7 +9,6 @@ import (
 	"strings"

 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )
@@ -158,14 +157,13 @@ func handleToolCalls(
 }

 func (c *OpenAIClient) CreateChatCompletion(
-	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	replies *[]model.Message,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
-	resp, err := client.CreateChatCompletion(ctx, req)
+	resp, err := client.CreateChatCompletion(context.Background(), req)
 	if err != nil {
 		return "", err
 	}
@@ -178,19 +176,17 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}

-		if callback != nil {
-			for _, result := range results {
-				callback(result)
-			}
+		if results != nil {
+			*replies = append(*replies, results...)
 		}

 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(ctx, params, messages, callback)
+		return c.CreateChatCompletion(params, messages, replies)
 	}

-	if callback != nil {
-		callback(model.Message{
+	if replies != nil {
+		*replies = append(*replies, model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})
@@ -201,16 +197,15 @@ func (c *OpenAIClient) CreateChatCompletion(
 }

 func (c *OpenAIClient) CreateChatCompletionStream(
-	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callbback provider.ReplyCallback,
+	replies *[]model.Message,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
-	stream, err := client.CreateChatCompletionStream(ctx, req)
+	stream, err := client.CreateChatCompletionStream(context.Background(), req)
 	if err != nil {
 		return "", err
 	}
@@ -255,20 +250,17 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 		if err != nil {
 			return content.String(), err
 		}

-		if callbback != nil {
-			for _, result := range results {
-				callbback(result)
-			}
+		if results != nil {
+			*replies = append(*replies, results...)
 		}

 		// Recurse into CreateChatCompletionStream with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
+		return c.CreateChatCompletionStream(params, messages, replies, output)
 	}

-	if callbback != nil {
-		callbback(model.Message{
+	if replies != nil {
+		*replies = append(*replies, model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: content.String(),
 		})
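For an end-to-end sanity check of the new control flow, this is roughly how the streaming entry point is now driven (a sketch assuming an already-configured `client`, `params`, and `messages`; the real plumbing lives in `FetchAndShowCompletion` above):

// Sketch: drive CreateChatCompletionStream after this change. The caller
// owns both the replies slice and the output channel.
var replies []model.Message
output := make(chan string)
done := make(chan struct{})
go func() {
	for chunk := range output {
		fmt.Print(chunk) // print chunks as they stream in
	}
	close(done)
}()
response, err := client.CreateChatCompletionStream(params, messages, &replies, output)
close(output)
<-done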

View File

@@ -1,31 +1,23 @@
 package provider

-import (
-	"context"
-
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-)
-
-type ReplyCallback func(model.Message)
+import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"

 type ChatCompletionClient interface {
+	// CreateChatCompletion requests a response to the provided messages.
+	// Replies are appended to the given replies struct, and the
+	// complete user-facing response is returned as a string.
 	CreateChatCompletion(
-		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		callback ReplyCallback,
+		replies *[]model.Message,
 	) (string, error)

 	// Like CreateChageCompletion, except the response is streamed via
 	// the output channel as it's received.
 	CreateChatCompletionStream(
-		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		callback ReplyCallback,
+		replies *[]model.Message,
 		output chan<- string,
 	) (string, error)
 }
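One nicety of the pointer-slice approach: the interface is now easy to fake in tests. A self-contained sketch of a hypothetical test double (not part of this diff):

package provider_test

import (
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
)

// stubClient is a hypothetical test double for the new interface.
type stubClient struct{ reply string }

// Compile-time check that the stub satisfies the new interface.
var _ provider.ChatCompletionClient = (*stubClient)(nil)

func (s *stubClient) CreateChatCompletion(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
) (string, error) {
	if replies != nil {
		*replies = append(*replies, model.Message{
			Role:    model.MessageRoleAssistant,
			Content: s.reply,
		})
	}
	return s.reply, nil
}

func (s *stubClient) CreateChatCompletionStream(
	params model.RequestParameters,
	messages []model.Message,
	replies *[]model.Message,
	output chan<- string,
) (string, error) {
	output <- s.reply // assumes a consumer is draining output
	return s.CreateChatCompletion(params, messages, replies)
}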