Moved api.ChatCompletionProvider, api.Chunk to api/provider
This commit is contained in:
parent
a441866f2f
commit
327a128b2f
@ -1,6 +1,9 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import "database/sql"
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
@ -9,3 +12,70 @@ type Conversation struct {
|
|||||||
SelectedRootID *uint
|
SelectedRootID *uint
|
||||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MessageRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageRoleSystem MessageRole = "system"
|
||||||
|
MessageRoleUser MessageRole = "user"
|
||||||
|
MessageRoleAssistant MessageRole = "assistant"
|
||||||
|
MessageRoleToolCall MessageRole = "tool_call"
|
||||||
|
MessageRoleToolResult MessageRole = "tool_result"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ConversationID *uint `gorm:"index"`
|
||||||
|
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||||
|
Content string
|
||||||
|
Role MessageRole
|
||||||
|
CreatedAt time.Time
|
||||||
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||||
|
ToolResults ToolResults // a json array of tool results
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||||
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||||
|
|
||||||
|
SelectedReplyID *uint
|
||||||
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||||
|
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||||
|
if force {
|
||||||
|
m[0].Content = system
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
} else {
|
||||||
|
return append([]Message{{
|
||||||
|
Role: MessageRoleSystem,
|
||||||
|
Content: system,
|
||||||
|
}}, m...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageRole) IsAssistant() bool {
|
||||||
|
switch *m {
|
||||||
|
case MessageRoleAssistant, MessageRoleToolCall:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
|
func (m MessageRole) FriendlyRole() string {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleUser:
|
||||||
|
return "You"
|
||||||
|
case MessageRoleSystem:
|
||||||
|
return "System"
|
||||||
|
case MessageRoleAssistant:
|
||||||
|
return "Assistant"
|
||||||
|
case MessageRoleToolCall:
|
||||||
|
return "Tool Call"
|
||||||
|
case MessageRoleToolResult:
|
||||||
|
return "Tool Result"
|
||||||
|
default:
|
||||||
|
return string(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,72 +0,0 @@
|
|||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageRole string
|
|
||||||
|
|
||||||
const (
|
|
||||||
MessageRoleSystem MessageRole = "system"
|
|
||||||
MessageRoleUser MessageRole = "user"
|
|
||||||
MessageRoleAssistant MessageRole = "assistant"
|
|
||||||
MessageRoleToolCall MessageRole = "tool_call"
|
|
||||||
MessageRoleToolResult MessageRole = "tool_result"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ConversationID *uint `gorm:"index"`
|
|
||||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
|
||||||
Content string
|
|
||||||
Role MessageRole
|
|
||||||
CreatedAt time.Time
|
|
||||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
|
||||||
ToolResults ToolResults // a json array of tool results
|
|
||||||
ParentID *uint
|
|
||||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
|
||||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
|
||||||
|
|
||||||
SelectedReplyID *uint
|
|
||||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
|
||||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
|
||||||
if force {
|
|
||||||
m[0].Content = system
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
} else {
|
|
||||||
return append([]Message{{
|
|
||||||
Role: MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
}}, m...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MessageRole) IsAssistant() bool {
|
|
||||||
switch *m {
|
|
||||||
case MessageRoleAssistant, MessageRoleToolCall:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
|
||||||
func (m MessageRole) FriendlyRole() string {
|
|
||||||
switch m {
|
|
||||||
case MessageRoleUser:
|
|
||||||
return "You"
|
|
||||||
case MessageRoleSystem:
|
|
||||||
return "System"
|
|
||||||
case MessageRoleAssistant:
|
|
||||||
return "Assistant"
|
|
||||||
case MessageRoleToolCall:
|
|
||||||
return "Tool Call"
|
|
||||||
case MessageRoleToolResult:
|
|
||||||
return "Tool Result"
|
|
||||||
default:
|
|
||||||
return string(m)
|
|
||||||
}
|
|
||||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ANTHROPIC_VERSION = "2023-06-01"
|
const ANTHROPIC_VERSION = "2023-06-01"
|
||||||
@ -117,7 +118,7 @@ func convertTools(tools []api.ToolSpec) []Tool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
func createChatCompletionRequest(
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (string, ChatCompletionRequest) {
|
) (string, ChatCompletionRequest) {
|
||||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
@ -188,7 +189,8 @@ func createChatCompletionRequest(
|
|||||||
}
|
}
|
||||||
|
|
||||||
var prefill string
|
var prefill string
|
||||||
if api.IsAssistantContinuation(messages) {
|
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
||||||
|
// Prompting on an assitant message, use its content as prefill
|
||||||
prefill = messages[len(messages)-1].Content
|
prefill = messages[len(messages)-1].Content
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,7 +228,7 @@ func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionReque
|
|||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletion(
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
@ -253,9 +255,9 @@ func (c *AnthropicClient) CreateChatCompletion(
|
|||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
output chan<- api.Chunk,
|
output chan<- provider.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("can't create completion from no messages")
|
return nil, fmt.Errorf("can't create completion from no messages")
|
||||||
@ -349,7 +351,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
firstChunkReceived = true
|
firstChunkReceived = true
|
||||||
}
|
}
|
||||||
block.Text += text
|
block.Text += text
|
||||||
output <- api.Chunk{
|
output <- provider.Chunk{
|
||||||
Content: text,
|
Content: text,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@ -172,7 +173,7 @@ func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionRespons
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createGenerateContentRequest(
|
func createGenerateContentRequest(
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*GenerateContentRequest, error) {
|
) (*GenerateContentRequest, error) {
|
||||||
requestContents := make([]Content, 0, len(messages))
|
requestContents := make([]Content, 0, len(messages))
|
||||||
@ -279,7 +280,7 @@ func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
|||||||
|
|
||||||
func (c *Client) CreateChatCompletion(
|
func (c *Client) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
@ -351,9 +352,9 @@ func (c *Client) CreateChatCompletion(
|
|||||||
|
|
||||||
func (c *Client) CreateChatCompletionStream(
|
func (c *Client) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
output chan<- api.Chunk,
|
output chan<- provider.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
@ -425,7 +426,7 @@ func (c *Client) CreateChatCompletionStream(
|
|||||||
if part.FunctionCall != nil {
|
if part.FunctionCall != nil {
|
||||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
} else if part.Text != "" {
|
} else if part.Text != "" {
|
||||||
output <- api.Chunk{
|
output <- provider.Chunk{
|
||||||
Content: part.Text,
|
Content: part.Text,
|
||||||
TokenCount: uint(tokens),
|
TokenCount: uint(tokens),
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OllamaClient struct {
|
type OllamaClient struct {
|
||||||
@ -42,7 +43,7 @@ type OllamaResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createOllamaRequest(
|
func createOllamaRequest(
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) OllamaRequest {
|
) OllamaRequest {
|
||||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||||
@ -82,7 +83,7 @@ func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
|||||||
|
|
||||||
func (c *OllamaClient) CreateChatCompletion(
|
func (c *OllamaClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
@ -122,9 +123,9 @@ func (c *OllamaClient) CreateChatCompletion(
|
|||||||
|
|
||||||
func (c *OllamaClient) CreateChatCompletionStream(
|
func (c *OllamaClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
output chan<- api.Chunk,
|
output chan<- provider.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
@ -173,7 +174,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(streamResp.Message.Content) > 0 {
|
if len(streamResp.Message.Content) > 0 {
|
||||||
output <- api.Chunk{
|
output <- provider.Chunk{
|
||||||
Content: streamResp.Message.Content,
|
Content: streamResp.Message.Content,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIClient struct {
|
type OpenAIClient struct {
|
||||||
@ -140,7 +141,7 @@ func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
func createChatCompletionRequest(
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) ChatCompletionRequest {
|
) ChatCompletionRequest {
|
||||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
@ -219,7 +220,7 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest)
|
|||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletion(
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
@ -267,9 +268,9 @@ func (c *OpenAIClient) CreateChatCompletion(
|
|||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params api.RequestParameters,
|
params provider.RequestParameters,
|
||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
output chan<- api.Chunk,
|
output chan<- provider.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
@ -333,7 +334,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(delta.Content) > 0 {
|
if len(delta.Content) > 0 {
|
||||||
output <- api.Chunk{
|
output <- provider.Chunk{
|
||||||
Content: delta.Content,
|
Content: delta.Content,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package api
|
package provider
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ReplyCallback func(Message)
|
type ReplyCallback func(api.Message)
|
||||||
|
|
||||||
type Chunk struct {
|
type Chunk struct {
|
||||||
Content string
|
Content string
|
||||||
@ -18,7 +20,7 @@ type RequestParameters struct {
|
|||||||
Temperature float32
|
Temperature float32
|
||||||
TopP float32
|
TopP float32
|
||||||
|
|
||||||
Toolbox []ToolSpec
|
Toolbox []api.ToolSpec
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionProvider interface {
|
type ChatCompletionProvider interface {
|
||||||
@ -28,22 +30,15 @@ type ChatCompletionProvider interface {
|
|||||||
CreateChatCompletion(
|
CreateChatCompletion(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params RequestParameters,
|
params RequestParameters,
|
||||||
messages []Message,
|
messages []api.Message,
|
||||||
) (*Message, error)
|
) (*api.Message, error)
|
||||||
|
|
||||||
// Like CreateChageCompletion, except the response is streamed via
|
// Like CreateChageCompletion, except the response is streamed via
|
||||||
// the output channel as it's received.
|
// the output channel as it's received.
|
||||||
CreateChatCompletionStream(
|
CreateChatCompletionStream(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
params RequestParameters,
|
params RequestParameters,
|
||||||
messages []Message,
|
messages []api.Message,
|
||||||
chunks chan<- Chunk,
|
chunks chan<- Chunk,
|
||||||
) (*Message, error)
|
) (*api.Message, error)
|
||||||
}
|
|
||||||
|
|
||||||
func IsAssistantContinuation(messages []Message) bool {
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return messages[len(messages)-1].Role == MessageRoleAssistant
|
|
||||||
}
|
}
|
@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
@ -17,12 +18,12 @@ import (
|
|||||||
// Prompt prompts the configured the configured model and streams the response
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
|
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
|
||||||
m, _, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := api.RequestParameters{
|
params := provider.RequestParameters{
|
||||||
Model: m,
|
Model: m,
|
||||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *ctx.Config.Defaults.Temperature,
|
Temperature: *ctx.Config.Defaults.Temperature,
|
||||||
@ -42,13 +43,13 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
messages = api.ApplySystemPrompt(messages, system, false)
|
messages = api.ApplySystemPrompt(messages, system, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make(chan api.Chunk)
|
content := make(chan provider.Chunk)
|
||||||
defer close(content)
|
defer close(content)
|
||||||
|
|
||||||
// render the content received over the channel
|
// render the content received over the channel
|
||||||
go ShowDelayedContent(content)
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
reply, err := provider.CreateChatCompletionStream(
|
reply, err := p.CreateChatCompletionStream(
|
||||||
context.Background(), params, messages, content,
|
context.Background(), params, messages, content,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -204,19 +205,19 @@ Example response:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, _, provider, err := ctx.GetModelProvider(
|
m, _, p, err := ctx.GetModelProvider(
|
||||||
*ctx.Config.Conversations.TitleGenerationModel, "",
|
*ctx.Config.Conversations.TitleGenerationModel, "",
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := api.RequestParameters{
|
requestParams := provider.RequestParameters{
|
||||||
Model: m,
|
Model: m,
|
||||||
MaxTokens: 25,
|
MaxTokens: 25,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := provider.CreateChatCompletion(
|
response, err := p.CreateChatCompletion(
|
||||||
context.Background(), requestParams, generateRequest,
|
context.Background(), requestParams, generateRequest,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -272,7 +273,7 @@ func ShowWaitAnimation(signal chan any) {
|
|||||||
// chunked) content is received on the channel, the waiting animation is
|
// chunked) content is received on the channel, the waiting animation is
|
||||||
// replaced by the content.
|
// replaced by the content.
|
||||||
// Blocks until the channel is closed.
|
// Blocks until the channel is closed.
|
||||||
func ShowDelayedContent(content <-chan api.Chunk) {
|
func ShowDelayedContent(content <-chan provider.Chunk) {
|
||||||
waitSignal := make(chan any)
|
waitSignal := make(chan any)
|
||||||
go ShowWaitAnimation(waitSignal)
|
go ShowWaitAnimation(waitSignal)
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||||
@ -161,7 +162,7 @@ func (c *Context) DefaultSystemPrompt() string {
|
|||||||
return c.Config.Defaults.SystemPrompt
|
return c.Config.Defaults.SystemPrompt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string, provider string) (string, string, api.ChatCompletionProvider, error) {
|
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
if provider == "" && len(parts) > 1 {
|
if provider == "" && len(parts) > 1 {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
@ -24,7 +25,7 @@ type AppModel struct {
|
|||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
Model string
|
Model string
|
||||||
ProviderName string
|
ProviderName string
|
||||||
Provider api.ChatCompletionProvider
|
Provider provider.ChatCompletionProvider
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
|
func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel {
|
||||||
@ -151,6 +152,28 @@ func (a *AppModel) UpdateMessageContent(message *api.Message) error {
|
|||||||
return a.Ctx.Store.UpdateMessage(message)
|
return a.Ctx.Store.UpdateMessage(message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
||||||
|
currentIndex := -1
|
||||||
|
for i, reply := range choices {
|
||||||
|
if reply.ID == selected.ID {
|
||||||
|
currentIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentIndex < 0 {
|
||||||
|
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var next int
|
||||||
|
if dir == CyclePrev {
|
||||||
|
next = (currentIndex - 1 + len(choices)) % len(choices)
|
||||||
|
} else {
|
||||||
|
next = (currentIndex + 1) % len(choices)
|
||||||
|
}
|
||||||
|
return &choices[next], nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
||||||
if len(rootMessages) < 2 {
|
if len(rootMessages) < 2 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -225,13 +248,13 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult,
|
|||||||
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) {
|
func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provider.Chunk, stopSignal chan struct{}) (*api.Message, error) {
|
||||||
model, _, provider, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
params := api.RequestParameters{
|
params := provider.RequestParameters{
|
||||||
Model: model,
|
Model: model,
|
||||||
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
||||||
@ -251,29 +274,7 @@ func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Ch
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return provider.CreateChatCompletionStream(
|
return p.CreateChatCompletionStream(
|
||||||
ctx, params, messages, chatReplyChunks,
|
ctx, params, messages, chatReplyChunks,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
|
||||||
currentIndex := -1
|
|
||||||
for i, reply := range choices {
|
|
||||||
if reply.ID == selected.ID {
|
|
||||||
currentIndex = i
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentIndex < 0 {
|
|
||||||
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
var next int
|
|
||||||
if dir == CyclePrev {
|
|
||||||
next = (currentIndex - 1 + len(choices)) % len(choices)
|
|
||||||
} else {
|
|
||||||
next = (currentIndex + 1) % len(choices)
|
|
||||||
}
|
|
||||||
return &choices[next], nil
|
|
||||||
}
|
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"github.com/charmbracelet/bubbles/cursor"
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
"github.com/charmbracelet/bubbles/spinner"
|
"github.com/charmbracelet/bubbles/spinner"
|
||||||
@ -33,7 +34,7 @@ type (
|
|||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
// sent on each chunk received from LLM
|
// sent on each chunk received from LLM
|
||||||
msgChatResponseChunk api.Chunk
|
msgChatResponseChunk provider.Chunk
|
||||||
// sent on each completed reply
|
// sent on each completed reply
|
||||||
msgChatResponse *api.Message
|
msgChatResponse *api.Message
|
||||||
// sent when the response is canceled
|
// sent when the response is canceled
|
||||||
@ -84,7 +85,7 @@ type Model struct {
|
|||||||
editorTarget editorTarget
|
editorTarget editorTarget
|
||||||
stopSignal chan struct{}
|
stopSignal chan struct{}
|
||||||
replyChan chan api.Message
|
replyChan chan api.Message
|
||||||
chatReplyChunks chan api.Chunk
|
chatReplyChunks chan provider.Chunk
|
||||||
persistence bool // whether we will save new messages in the conversation
|
persistence bool // whether we will save new messages in the conversation
|
||||||
|
|
||||||
// UI state
|
// UI state
|
||||||
@ -115,7 +116,7 @@ func Chat(app *model.AppModel) *Model {
|
|||||||
|
|
||||||
stopSignal: make(chan struct{}),
|
stopSignal: make(chan struct{}),
|
||||||
replyChan: make(chan api.Message),
|
replyChan: make(chan api.Message),
|
||||||
chatReplyChunks: make(chan api.Chunk),
|
chatReplyChunks: make(chan provider.Chunk),
|
||||||
|
|
||||||
wrap: true,
|
wrap: true,
|
||||||
selectedMessage: -1,
|
selectedMessage: -1,
|
||||||
|
@ -199,10 +199,10 @@ func (m *Model) renderMessage(i int) string {
|
|||||||
|
|
||||||
// render the conversation into a string
|
// render the conversation into a string
|
||||||
func (m *Model) conversationMessagesView() string {
|
func (m *Model) conversationMessagesView() string {
|
||||||
sb := strings.Builder{}
|
|
||||||
|
|
||||||
m.messageOffsets = make([]int, len(m.App.Messages))
|
m.messageOffsets = make([]int, len(m.App.Messages))
|
||||||
lineCnt := 1
|
lineCnt := 1
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
for i, message := range m.App.Messages {
|
for i, message := range m.App.Messages {
|
||||||
m.messageOffsets[i] = lineCnt
|
m.messageOffsets[i] = lineCnt
|
||||||
|
|
||||||
@ -227,7 +227,6 @@ func (m *Model) conversationMessagesView() string {
|
|||||||
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
|
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
|
||||||
sb.WriteString("\n")
|
sb.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user