Compare commits

045146bb5c...91d3c9c2e1

2 commits:
- 91d3c9c2e1
- 8bdb155bf7

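These two commits make a pair of mechanical changes across the CLI commands and both completion providers: a context.Context is threaded from the call sites down to the providers' HTTP requests, and the replies *[]model.Message out-parameter is replaced by a provider.ReplyCallback invoked once per reply message. As a consequence, cmdutil.FetchAndShowCompletion now returns the completed response as a string rather than a slice of messages.
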
@@ -44,13 +44,13 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)
 
 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
 
 			// Append the new response to the original message
-			lastMessage.Content += strings.TrimRight(continuedOutput[0].Content, "\n\t ")
+			lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
 
 			// Update the original message
 			err = ctx.Store.UpdateMessage(lastMessage)

@@ -31,7 +31,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 			},
 		}
 
-		_, err := cmdutil.FetchAndShowCompletion(ctx, messages)
+		_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
 		if err != nil {
 			return fmt.Errorf("Error fetching LLM response: %v", err)
 		}

@@ -1,6 +1,7 @@
 package util
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"strings"

@@ -14,7 +15,7 @@ import (
 
 // fetchAndShowCompletion prompts the LLM with the given messages and streams
 // the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]model.Message, error) {
+func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
 

@@ -23,7 +24,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 
 	completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
 	if err != nil {
-		return nil, err
+		return "", err
 	}
 
 	requestParams := model.RequestParameters{

@@ -33,9 +34,8 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 		ToolBag: ctx.EnabledTools,
 	}
 
-	var apiReplies []model.Message
 	response, err := completionProvider.CreateChatCompletionStream(
-		requestParams, messages, &apiReplies, content,
+		context.Background(), requestParams, messages, callback, content,
 	)
 	if response != "" {
 		// there was some content, so break to a new line after it

@@ -46,8 +46,7 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message) ([]mod
 			err = nil
 		}
 	}
-
-	return apiReplies, err
+	return response, nil
 }
 
 // lookupConversation either returns the conversation found by the

@@ -98,20 +97,21 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 	// render a message header with no contents
 	RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
 
-	replies, err := FetchAndShowCompletion(ctx, allMessages)
-	if err != nil {
-		lmcli.Fatal("Error fetching LLM response: %v\n", err)
-	}
-
-	if persist {
-		for _, reply := range replies {
-			reply.ConversationID = c.ID
-
-			err = ctx.Store.SaveMessage(&reply)
-			if err != nil {
-				lmcli.Warn("Could not save reply: %v\n", err)
-			}
-		}
-
+	replyCallback := func(reply model.Message) {
+		if !persist {
+			return
+		}
+
+		reply.ConversationID = c.ID
+
+		err = ctx.Store.SaveMessage(&reply)
+		if err != nil {
+			lmcli.Warn("Could not save reply: %v\n", err)
+		}
+	}
+
+	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	if err != nil {
+		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
 }

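Under the new signature, callers that want the reply messages supply a callback, and callers that don't pass nil (as ContinueCmd and PromptCmd do above). A minimal sketch of recovering the old collect-into-a-slice behavior; the collectInto helper is hypothetical, not part of this diff:

	// collectInto returns a callback that appends each reply to *dst,
	// reproducing what the old replies *[]model.Message out-parameter did.
	// Hypothetical helper, for illustration only.
	func collectInto(dst *[]model.Message) func(model.Message) {
		return func(m model.Message) {
			*dst = append(*dst, m)
		}
	}

	// Usage sketch:
	//	var replies []model.Message
	//	response, err := FetchAndShowCompletion(ctx, allMessages, collectInto(&replies))
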
@@ -153,7 +153,7 @@ func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
 		MaxTokens: 25,
 	}
 
-	response, err := completionProvider.CreateChatCompletion(requestParams, generateRequest, nil)
+	response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
 	if err != nil {
 		return "", err
 	}

@@ -3,14 +3,15 @@ package anthropic
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"encoding/json"
 	"encoding/xml"
 	"fmt"
 	"net/http"
 	"strings"
 
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )
 
 type AnthropicClient struct {

@@ -102,7 +103,7 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 	return requestBody
 }
 
-func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
+func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
 	url := "https://api.anthropic.com/v1/messages"
 
 	jsonBody, err := json.Marshal(r)

@@ -110,7 +111,7 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
 		return nil, fmt.Errorf("failed to marshal request body: %v", err)
 	}
 
-	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create HTTP request: %v", err)
 	}

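Replacing http.NewRequest with http.NewRequestWithContext is what makes the threaded context useful: cancelling the caller's context now aborts an in-flight Anthropic request. A sketch of a hypothetical in-package call site; the 30-second deadline is illustrative, not from the diff:

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// The request fails with a context error once the deadline passes.
	resp, err := sendRequest(ctx, c, request)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
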
@@ -129,13 +130,14 @@ func sendRequest(c *AnthropicClient, r Request) (*http.Response, error) {
 }
 
 func (c *AnthropicClient) CreateChatCompletion(
+	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	request := buildRequest(params, messages)
 
-	resp, err := sendRequest(c, request)
+	resp, err := sendRequest(ctx, c, request)
 	if err != nil {
 		return "", err
 	}

@@ -160,22 +162,25 @@ func (c *AnthropicClient) CreateChatCompletion(
 		default:
 			return "", fmt.Errorf("unsupported message type: %s", content.Type)
 		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(reply)
+		}
 	}
 
 	return sb.String(), nil
 }
 
 func (c *AnthropicClient) CreateChatCompletionStream(
+	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	request := buildRequest(params, messages)
 	request.Stream = true
 
-	resp, err := sendRequest(c, request)
+	resp, err := sendRequest(ctx, c, request)
 	if err != nil {
 		return "", err
 	}

@@ -288,23 +293,25 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					ToolResults: toolResults,
 				}
 
-				if replies != nil {
-					*replies = append(append(*replies, toolCall), toolReply)
+				if callback != nil {
+					callback(toolCall)
+					callback(toolReply)
 				}
 
 				// Recurse into CreateChatCompletionStream with the tool call replies
 				// added to the original messages
 				messages = append(append(messages, toolCall), toolReply)
-				return c.CreateChatCompletionStream(params, messages, replies, output)
+				return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
 			}
 		}
 	case "message_stop":
 		// return the completed message
-		reply := model.Message{
-			Role:    model.MessageRoleAssistant,
-			Content: sb.String(),
-		}
-		*replies = append(*replies, reply)
+		if callback != nil {
+			callback(model.Message{
+				Role:    model.MessageRoleAssistant,
+				Content: sb.String(),
+			})
+		}
 		return sb.String(), nil
 	case "error":
 		return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])

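Each call site now guards callback invocations with an explicit nil check, and the same conversion is applied to the OpenAI provider below. If the repetition grew, the checks could be centralized in a small helper along these lines (hypothetical, not part of this change):

	// invoke calls cb once per message, tolerating a nil callback.
	// Hypothetical helper; the diff inlines this check at each call site.
	func invoke(cb provider.ReplyCallback, msgs ...model.Message) {
		if cb == nil {
			return
		}
		for _, m := range msgs {
			cb(m)
		}
	}
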
@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	openai "github.com/sashabaranov/go-openai"
 )

@@ -157,13 +158,14 @@ func handleToolCalls(
 }
 
 func (c *OpenAIClient) CreateChatCompletion(
+	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callback provider.ReplyCallback,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
-	resp, err := client.CreateChatCompletion(context.Background(), req)
+	resp, err := client.CreateChatCompletion(ctx, req)
 	if err != nil {
 		return "", err
 	}

@@ -176,17 +178,19 @@ func (c *OpenAIClient) CreateChatCompletion(
 		if err != nil {
 			return "", err
 		}
-		if results != nil {
-			*replies = append(*replies, results...)
+		if callback != nil {
+			for _, result := range results {
+				callback(result)
+			}
 		}
 
 		// Recurse into CreateChatCompletion with the tool call replies
 		messages = append(messages, results...)
-		return c.CreateChatCompletion(params, messages, replies)
+		return c.CreateChatCompletion(ctx, params, messages, callback)
 	}
 
-	if replies != nil {
-		*replies = append(*replies, model.Message{
+	if callback != nil {
+		callback(model.Message{
 			Role:    model.MessageRoleAssistant,
 			Content: choice.Message.Content,
 		})

@@ -197,15 +201,16 @@ func (c *OpenAIClient) CreateChatCompletion(
 }
 
 func (c *OpenAIClient) CreateChatCompletionStream(
+	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	replies *[]model.Message,
+	callbback provider.ReplyCallback,
 	output chan<- string,
 ) (string, error) {
 	client := openai.NewClient(c.APIKey)
 	req := createChatCompletionRequest(c, params, messages)
 
-	stream, err := client.CreateChatCompletionStream(context.Background(), req)
+	stream, err := client.CreateChatCompletionStream(ctx, req)
 	if err != nil {
 		return "", err
 	}

@@ -250,17 +255,20 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			if err != nil {
 				return content.String(), err
 			}
-			if results != nil {
-				*replies = append(*replies, results...)
+			if callbback != nil {
+				for _, result := range results {
+					callbback(result)
+				}
 			}
 
 			// Recurse into CreateChatCompletionStream with the tool call replies
 			messages = append(messages, results...)
-			return c.CreateChatCompletionStream(params, messages, replies, output)
+			return c.CreateChatCompletionStream(ctx, params, messages, callbback, output)
 		}
 
-		if replies != nil {
-			*replies = append(*replies, model.Message{
+		if callbback != nil {
+			callbback(model.Message{
 				Role:    model.MessageRoleAssistant,
 				Content: content.String(),
 			})

@@ -1,23 +1,31 @@
 package provider
 
-import "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+import (
+	"context"
+
+	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
+)
+
+type ReplyCallback func(model.Message)
 
 type ChatCompletionClient interface {
 	// CreateChatCompletion requests a response to the provided messages.
 	// Replies are appended to the given replies struct, and the
 	// complete user-facing response is returned as a string.
 	CreateChatCompletion(
+		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 	) (string, error)
 
 	// Like CreateChageCompletion, except the response is streamed via
 	// the output channel as it's received.
 	CreateChatCompletionStream(
+		ctx context.Context,
 		params model.RequestParameters,
 		messages []model.Message,
-		replies *[]model.Message,
+		callback ReplyCallback,
 		output chan<- string,
 	) (string, error)
 }
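
A minimal sketch of driving the updated interface, assuming only the types shown in this diff plus the standard context and time packages; the complete function and its names are illustrative, not part of the commits:

	// complete requests a completion with a timeout, gathering reply
	// messages through the new callback instead of an out-parameter.
	func complete(
		client provider.ChatCompletionClient,
		params model.RequestParameters,
		messages []model.Message,
	) ([]model.Message, string, error) {
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
		defer cancel()

		var replies []model.Message
		response, err := client.CreateChatCompletion(ctx, params, messages, func(m model.Message) {
			replies = append(replies, m)
		})
		return replies, response, err
	}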