From 327a128b2fd2b4f13d31e6f7c9bf075659eb0b33 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 30 Sep 2024 16:14:11 +0000 Subject: [PATCH] Moved api.ChatCompletionProvider, api.Chunk to api/provider --- pkg/api/conversation.go | 72 +++++++++++++++++++++++- pkg/api/message.go | 72 ------------------------ pkg/api/provider/anthropic/anthropic.go | 14 +++-- pkg/api/provider/google/google.go | 11 ++-- pkg/api/provider/ollama/ollama.go | 11 ++-- pkg/api/provider/openai/openai.go | 11 ++-- pkg/api/{api.go => provider/provider.go} | 23 +++----- pkg/cmd/util/util.go | 21 +++---- pkg/lmcli/lmcli.go | 3 +- pkg/tui/model/model.go | 55 +++++++++--------- pkg/tui/views/chat/chat.go | 7 ++- pkg/tui/views/chat/view.go | 5 +- 12 files changed, 153 insertions(+), 152 deletions(-) delete mode 100644 pkg/api/message.go rename pkg/api/{api.go => provider/provider.go} (69%) diff --git a/pkg/api/conversation.go b/pkg/api/conversation.go index 1ee1064..ac972df 100644 --- a/pkg/api/conversation.go +++ b/pkg/api/conversation.go @@ -1,6 +1,9 @@ package api -import "database/sql" +import ( + "database/sql" + "time" +) type Conversation struct { ID uint `gorm:"primaryKey"` @@ -9,3 +12,70 @@ type Conversation struct { SelectedRootID *uint SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"` } + +type MessageRole string + +const ( + MessageRoleSystem MessageRole = "system" + MessageRoleUser MessageRole = "user" + MessageRoleAssistant MessageRole = "assistant" + MessageRoleToolCall MessageRole = "tool_call" + MessageRoleToolResult MessageRole = "tool_result" +) + +type Message struct { + ID uint `gorm:"primaryKey"` + ConversationID *uint `gorm:"index"` + Conversation *Conversation `gorm:"foreignKey:ConversationID"` + Content string + Role MessageRole + CreatedAt time.Time + ToolCalls ToolCalls // a json array of tool calls (from the model) + ToolResults ToolResults // a json array of tool results + ParentID *uint + Parent *Message `gorm:"foreignKey:ParentID"` + Replies []Message `gorm:"foreignKey:ParentID"` + + SelectedReplyID *uint + SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` +} + +func ApplySystemPrompt(m []Message, system string, force bool) []Message { + if len(m) > 0 && m[0].Role == MessageRoleSystem { + if force { + m[0].Content = system + } + return m + } else { + return append([]Message{{ + Role: MessageRoleSystem, + Content: system, + }}, m...) + } +} + +func (m *MessageRole) IsAssistant() bool { + switch *m { + case MessageRoleAssistant, MessageRoleToolCall: + return true + } + return false +} + +// FriendlyRole returns a human friendly signifier for the message's role. +func (m MessageRole) FriendlyRole() string { + switch m { + case MessageRoleUser: + return "You" + case MessageRoleSystem: + return "System" + case MessageRoleAssistant: + return "Assistant" + case MessageRoleToolCall: + return "Tool Call" + case MessageRoleToolResult: + return "Tool Result" + default: + return string(m) + } +} diff --git a/pkg/api/message.go b/pkg/api/message.go deleted file mode 100644 index e51977d..0000000 --- a/pkg/api/message.go +++ /dev/null @@ -1,72 +0,0 @@ -package api - -import ( - "time" -) - -type MessageRole string - -const ( - MessageRoleSystem MessageRole = "system" - MessageRoleUser MessageRole = "user" - MessageRoleAssistant MessageRole = "assistant" - MessageRoleToolCall MessageRole = "tool_call" - MessageRoleToolResult MessageRole = "tool_result" -) - -type Message struct { - ID uint `gorm:"primaryKey"` - ConversationID *uint `gorm:"index"` - Conversation *Conversation `gorm:"foreignKey:ConversationID"` - Content string - Role MessageRole - CreatedAt time.Time - ToolCalls ToolCalls // a json array of tool calls (from the model) - ToolResults ToolResults // a json array of tool results - ParentID *uint - Parent *Message `gorm:"foreignKey:ParentID"` - Replies []Message `gorm:"foreignKey:ParentID"` - - SelectedReplyID *uint - SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` -} - -func ApplySystemPrompt(m []Message, system string, force bool) []Message { - if len(m) > 0 && m[0].Role == MessageRoleSystem { - if force { - m[0].Content = system - } - return m - } else { - return append([]Message{{ - Role: MessageRoleSystem, - Content: system, - }}, m...) - } -} - -func (m *MessageRole) IsAssistant() bool { - switch *m { - case MessageRoleAssistant, MessageRoleToolCall: - return true - } - return false -} - -// FriendlyRole returns a human friendly signifier for the message's role. -func (m MessageRole) FriendlyRole() string { - switch m { - case MessageRoleUser: - return "You" - case MessageRoleSystem: - return "System" - case MessageRoleAssistant: - return "Assistant" - case MessageRoleToolCall: - return "Tool Call" - case MessageRoleToolResult: - return "Tool Result" - default: - return string(m) - } -} diff --git a/pkg/api/provider/anthropic/anthropic.go b/pkg/api/provider/anthropic/anthropic.go index 9b3f25d..4da2ae3 100644 --- a/pkg/api/provider/anthropic/anthropic.go +++ b/pkg/api/provider/anthropic/anthropic.go @@ -11,6 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" ) const ANTHROPIC_VERSION = "2023-06-01" @@ -117,7 +118,7 @@ func convertTools(tools []api.ToolSpec) []Tool { } func createChatCompletionRequest( - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (string, ChatCompletionRequest) { requestMessages := make([]ChatCompletionMessage, 0, len(messages)) @@ -188,7 +189,8 @@ func createChatCompletionRequest( } var prefill string - if api.IsAssistantContinuation(messages) { + if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant { + // Prompting on an assitant message, use its content as prefill prefill = messages[len(messages)-1].Content } @@ -226,7 +228,7 @@ func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionReque func (c *AnthropicClient) CreateChatCompletion( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (*api.Message, error) { if len(messages) == 0 { @@ -253,9 +255,9 @@ func (c *AnthropicClient) CreateChatCompletion( func (c *AnthropicClient) CreateChatCompletionStream( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, - output chan<- api.Chunk, + output chan<- provider.Chunk, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("can't create completion from no messages") @@ -349,7 +351,7 @@ func (c *AnthropicClient) CreateChatCompletionStream( firstChunkReceived = true } block.Text += text - output <- api.Chunk{ + output <- provider.Chunk{ Content: text, TokenCount: 1, } diff --git a/pkg/api/provider/google/google.go b/pkg/api/provider/google/google.go index 3290f80..d061d24 100644 --- a/pkg/api/provider/google/google.go +++ b/pkg/api/provider/google/google.go @@ -11,6 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" ) type Client struct { @@ -172,7 +173,7 @@ func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionRespons } func createGenerateContentRequest( - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (*GenerateContentRequest, error) { requestContents := make([]Content, 0, len(messages)) @@ -279,7 +280,7 @@ func (c *Client) sendRequest(req *http.Request) (*http.Response, error) { func (c *Client) CreateChatCompletion( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (*api.Message, error) { if len(messages) == 0 { @@ -351,9 +352,9 @@ func (c *Client) CreateChatCompletion( func (c *Client) CreateChatCompletionStream( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, - output chan<- api.Chunk, + output chan<- provider.Chunk, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("Can't create completion from no messages") @@ -425,7 +426,7 @@ func (c *Client) CreateChatCompletionStream( if part.FunctionCall != nil { toolCalls = append(toolCalls, *part.FunctionCall) } else if part.Text != "" { - output <- api.Chunk{ + output <- provider.Chunk{ Content: part.Text, TokenCount: uint(tokens), } diff --git a/pkg/api/provider/ollama/ollama.go b/pkg/api/provider/ollama/ollama.go index 960c282..264aca7 100644 --- a/pkg/api/provider/ollama/ollama.go +++ b/pkg/api/provider/ollama/ollama.go @@ -11,6 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" ) type OllamaClient struct { @@ -42,7 +43,7 @@ type OllamaResponse struct { } func createOllamaRequest( - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) OllamaRequest { requestMessages := make([]OllamaMessage, 0, len(messages)) @@ -82,7 +83,7 @@ func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) { func (c *OllamaClient) CreateChatCompletion( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (*api.Message, error) { if len(messages) == 0 { @@ -122,9 +123,9 @@ func (c *OllamaClient) CreateChatCompletion( func (c *OllamaClient) CreateChatCompletionStream( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, - output chan<- api.Chunk, + output chan<- provider.Chunk, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("Can't create completion from no messages") @@ -173,7 +174,7 @@ func (c *OllamaClient) CreateChatCompletionStream( } if len(streamResp.Message.Content) > 0 { - output <- api.Chunk{ + output <- provider.Chunk{ Content: streamResp.Message.Content, TokenCount: 1, } diff --git a/pkg/api/provider/openai/openai.go b/pkg/api/provider/openai/openai.go index 6c376c8..318c392 100644 --- a/pkg/api/provider/openai/openai.go +++ b/pkg/api/provider/openai/openai.go @@ -11,6 +11,7 @@ import ( "strings" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" ) type OpenAIClient struct { @@ -140,7 +141,7 @@ func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall { } func createChatCompletionRequest( - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) ChatCompletionRequest { requestMessages := make([]ChatCompletionMessage, 0, len(messages)) @@ -219,7 +220,7 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) func (c *OpenAIClient) CreateChatCompletion( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, ) (*api.Message, error) { if len(messages) == 0 { @@ -267,9 +268,9 @@ func (c *OpenAIClient) CreateChatCompletion( func (c *OpenAIClient) CreateChatCompletionStream( ctx context.Context, - params api.RequestParameters, + params provider.RequestParameters, messages []api.Message, - output chan<- api.Chunk, + output chan<- provider.Chunk, ) (*api.Message, error) { if len(messages) == 0 { return nil, fmt.Errorf("Can't create completion from no messages") @@ -333,7 +334,7 @@ func (c *OpenAIClient) CreateChatCompletionStream( } } if len(delta.Content) > 0 { - output <- api.Chunk{ + output <- provider.Chunk{ Content: delta.Content, TokenCount: 1, } diff --git a/pkg/api/api.go b/pkg/api/provider/provider.go similarity index 69% rename from pkg/api/api.go rename to pkg/api/provider/provider.go index c8c54ad..0e1bdbb 100644 --- a/pkg/api/api.go +++ b/pkg/api/provider/provider.go @@ -1,10 +1,12 @@ -package api +package provider import ( "context" + + "git.mlow.ca/mlow/lmcli/pkg/api" ) -type ReplyCallback func(Message) +type ReplyCallback func(api.Message) type Chunk struct { Content string @@ -18,7 +20,7 @@ type RequestParameters struct { Temperature float32 TopP float32 - Toolbox []ToolSpec + Toolbox []api.ToolSpec } type ChatCompletionProvider interface { @@ -28,22 +30,15 @@ type ChatCompletionProvider interface { CreateChatCompletion( ctx context.Context, params RequestParameters, - messages []Message, - ) (*Message, error) + messages []api.Message, + ) (*api.Message, error) // Like CreateChageCompletion, except the response is streamed via // the output channel as it's received. CreateChatCompletionStream( ctx context.Context, params RequestParameters, - messages []Message, + messages []api.Message, chunks chan<- Chunk, - ) (*Message, error) -} - -func IsAssistantContinuation(messages []Message) bool { - if len(messages) == 0 { - return false - } - return messages[len(messages)-1].Role == MessageRoleAssistant + ) (*api.Message, error) } diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 5d853fc..2ca2d6f 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -9,6 +9,7 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "git.mlow.ca/mlow/lmcli/pkg/util" "github.com/charmbracelet/lipgloss" @@ -17,12 +18,12 @@ import ( // Prompt prompts the configured the configured model and streams the response // to stdout. Returns all model reply messages. func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) { - m, _, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") + m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "") if err != nil { return nil, err } - params := api.RequestParameters{ + params := provider.RequestParameters{ Model: m, MaxTokens: *ctx.Config.Defaults.MaxTokens, Temperature: *ctx.Config.Defaults.Temperature, @@ -42,13 +43,13 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag messages = api.ApplySystemPrompt(messages, system, false) } - content := make(chan api.Chunk) + content := make(chan provider.Chunk) defer close(content) // render the content received over the channel go ShowDelayedContent(content) - reply, err := provider.CreateChatCompletionStream( + reply, err := p.CreateChatCompletionStream( context.Background(), params, messages, content, ) @@ -182,8 +183,8 @@ Example response: var msgs []msg for _, m := range messages { switch m.Role { - case api.MessageRoleAssistant, api.MessageRoleUser: - msgs = append(msgs, msg{string(m.Role), m.Content}) + case api.MessageRoleAssistant, api.MessageRoleUser: + msgs = append(msgs, msg{string(m.Role), m.Content}) } } @@ -204,19 +205,19 @@ Example response: }, } - m, _, provider, err := ctx.GetModelProvider( + m, _, p, err := ctx.GetModelProvider( *ctx.Config.Conversations.TitleGenerationModel, "", ) if err != nil { return "", err } - requestParams := api.RequestParameters{ + requestParams := provider.RequestParameters{ Model: m, MaxTokens: 25, } - response, err := provider.CreateChatCompletion( + response, err := p.CreateChatCompletion( context.Background(), requestParams, generateRequest, ) if err != nil { @@ -272,7 +273,7 @@ func ShowWaitAnimation(signal chan any) { // chunked) content is received on the channel, the waiting animation is // replaced by the content. // Blocks until the channel is closed. -func ShowDelayedContent(content <-chan api.Chunk) { +func ShowDelayedContent(content <-chan provider.Chunk) { waitSignal := make(chan any) go ShowWaitAnimation(waitSignal) diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index e75bc2a..e8b3d48 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -12,6 +12,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" "git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic" "git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" @@ -161,7 +162,7 @@ func (c *Context) DefaultSystemPrompt() string { return c.Config.Defaults.SystemPrompt } -func (c *Context) GetModelProvider(model string, provider string) (string, string, api.ChatCompletionProvider, error) { +func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) { parts := strings.Split(model, "@") if provider == "" && len(parts) > 1 { diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 72371d6..74e0651 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -6,6 +6,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/agents" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util" "git.mlow.ca/mlow/lmcli/pkg/lmcli" "github.com/charmbracelet/lipgloss" @@ -24,7 +25,7 @@ type AppModel struct { Messages []api.Message Model string ProviderName string - Provider api.ChatCompletionProvider + Provider provider.ChatCompletionProvider } func NewAppModel(ctx *lmcli.Context, initialConversation *api.Conversation) *AppModel { @@ -151,6 +152,28 @@ func (a *AppModel) UpdateMessageContent(message *api.Message) error { return a.Ctx.Store.UpdateMessage(message) } +func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) { + currentIndex := -1 + for i, reply := range choices { + if reply.ID == selected.ID { + currentIndex = i + break + } + } + + if currentIndex < 0 { + return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID) + } + + var next int + if dir == CyclePrev { + next = (currentIndex - 1 + len(choices)) % len(choices) + } else { + next = (currentIndex + 1) % len(choices) + } + return &choices[next], nil +} + func (a *AppModel) CycleSelectedRoot(conv *api.Conversation, rootMessages []api.Message, dir MessageCycleDirection) (*api.Message, error) { if len(rootMessages) < 2 { return nil, nil @@ -225,13 +248,13 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, return agents.ExecuteToolCalls(toolCalls, agent.Toolbox) } -func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Chunk, stopSignal chan struct{}) (*api.Message, error) { - model, _, provider, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) +func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provider.Chunk, stopSignal chan struct{}) (*api.Message, error) { + model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) if err != nil { return nil, err } - params := api.RequestParameters{ + params := provider.RequestParameters{ Model: model, MaxTokens: *a.Ctx.Config.Defaults.MaxTokens, Temperature: *a.Ctx.Config.Defaults.Temperature, @@ -251,29 +274,7 @@ func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan api.Ch } }() - return provider.CreateChatCompletionStream( + return p.CreateChatCompletionStream( ctx, params, messages, chatReplyChunks, ) } - -func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) { - currentIndex := -1 - for i, reply := range choices { - if reply.ID == selected.ID { - currentIndex = i - break - } - } - - if currentIndex < 0 { - return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID) - } - - var next int - if dir == CyclePrev { - next = (currentIndex - 1 + len(choices)) % len(choices) - } else { - next = (currentIndex + 1) % len(choices) - } - return &choices[next], nil -} diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go index 3c56fbb..726f66a 100644 --- a/pkg/tui/views/chat/chat.go +++ b/pkg/tui/views/chat/chat.go @@ -4,6 +4,7 @@ import ( "time" "git.mlow.ca/mlow/lmcli/pkg/api" + "git.mlow.ca/mlow/lmcli/pkg/api/provider" "git.mlow.ca/mlow/lmcli/pkg/tui/model" "github.com/charmbracelet/bubbles/cursor" "github.com/charmbracelet/bubbles/spinner" @@ -33,7 +34,7 @@ type ( Err error } // sent on each chunk received from LLM - msgChatResponseChunk api.Chunk + msgChatResponseChunk provider.Chunk // sent on each completed reply msgChatResponse *api.Message // sent when the response is canceled @@ -84,7 +85,7 @@ type Model struct { editorTarget editorTarget stopSignal chan struct{} replyChan chan api.Message - chatReplyChunks chan api.Chunk + chatReplyChunks chan provider.Chunk persistence bool // whether we will save new messages in the conversation // UI state @@ -115,7 +116,7 @@ func Chat(app *model.AppModel) *Model { stopSignal: make(chan struct{}), replyChan: make(chan api.Message), - chatReplyChunks: make(chan api.Chunk), + chatReplyChunks: make(chan provider.Chunk), wrap: true, selectedMessage: -1, diff --git a/pkg/tui/views/chat/view.go b/pkg/tui/views/chat/view.go index f7adc9e..ba108df 100644 --- a/pkg/tui/views/chat/view.go +++ b/pkg/tui/views/chat/view.go @@ -199,10 +199,10 @@ func (m *Model) renderMessage(i int) string { // render the conversation into a string func (m *Model) conversationMessagesView() string { - sb := strings.Builder{} - m.messageOffsets = make([]int, len(m.App.Messages)) lineCnt := 1 + + sb := strings.Builder{} for i, message := range m.App.Messages { m.messageOffsets[i] = lineCnt @@ -227,7 +227,6 @@ func (m *Model) conversationMessagesView() string { sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View())) sb.WriteString("\n") } - return sb.String() }