From 5d13c3e05627ca44cb04f559190e419f7389cbe3 Mon Sep 17 00:00:00 2001 From: Matt Low Date: Mon, 30 Sep 2024 17:37:50 +0000 Subject: [PATCH] Add metadata json field to Message, store generation model/provider --- pkg/api/conversation.go | 41 +++++++++++++++++++++++++----------- pkg/api/provider/provider.go | 9 +++----- pkg/tui/model/model.go | 13 ++++++++++-- pkg/tui/views/chat/cmds.go | 2 +- 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/pkg/api/conversation.go b/pkg/api/conversation.go index ac972df..116fd49 100644 --- a/pkg/api/conversation.go +++ b/pkg/api/conversation.go @@ -2,6 +2,8 @@ package api import ( "database/sql" + "database/sql/driver" + "encoding/json" "time" ) @@ -23,21 +25,36 @@ const ( MessageRoleToolResult MessageRole = "tool_result" ) -type Message struct { - ID uint `gorm:"primaryKey"` - ConversationID *uint `gorm:"index"` - Conversation *Conversation `gorm:"foreignKey:ConversationID"` - Content string - Role MessageRole - CreatedAt time.Time - ToolCalls ToolCalls // a json array of tool calls (from the model) - ToolResults ToolResults // a json array of tool results - ParentID *uint - Parent *Message `gorm:"foreignKey:ParentID"` - Replies []Message `gorm:"foreignKey:ParentID"` +type MessageMeta struct { + GenerationProvider *string `json:"generation_provider,omitempty"` + GenerationModel *string `json:"generation_model,omitempty"` +} +type Message struct { + ID uint `gorm:"primaryKey"` + CreatedAt time.Time + Metadata MessageMeta + + ConversationID *uint `gorm:"index"` + Conversation *Conversation `gorm:"foreignKey:ConversationID"` + ParentID *uint + Parent *Message `gorm:"foreignKey:ParentID"` + Replies []Message `gorm:"foreignKey:ParentID"` SelectedReplyID *uint SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` + + Role MessageRole + Content string + ToolCalls ToolCalls // a json array of tool calls (from the model) + ToolResults ToolResults // a json array of tool results +} + +func (m *MessageMeta) Scan(value interface{}) error { + return json.Unmarshal(value.([]byte), m) +} + +func (m MessageMeta) Value() (driver.Value, error) { + return json.Marshal(m) } func ApplySystemPrompt(m []Message, system string, force bool) []Message { diff --git a/pkg/api/provider/provider.go b/pkg/api/provider/provider.go index 0e1bdbb..e14b7da 100644 --- a/pkg/api/provider/provider.go +++ b/pkg/api/provider/provider.go @@ -6,8 +6,6 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api" ) -type ReplyCallback func(api.Message) - type Chunk struct { Content string TokenCount uint @@ -24,9 +22,8 @@ type RequestParameters struct { } type ChatCompletionProvider interface { - // CreateChatCompletion requests a response to the provided messages. - // Replies are appended to the given replies struct, and the - // complete user-facing response is returned as a string. + // CreateChatCompletion generates a chat completion response to the + // provided messages. CreateChatCompletion( ctx context.Context, params RequestParameters, @@ -34,7 +31,7 @@ type ChatCompletionProvider interface { ) (*api.Message, error) // Like CreateChageCompletion, except the response is streamed via - // the output channel as it's received. + // the output channel. CreateChatCompletionStream( ctx context.Context, params RequestParameters, diff --git a/pkg/tui/model/model.go b/pkg/tui/model/model.go index 74e0651..a657b28 100644 --- a/pkg/tui/model/model.go +++ b/pkg/tui/model/model.go @@ -248,7 +248,11 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, return agents.ExecuteToolCalls(toolCalls, agent.Toolbox) } -func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provider.Chunk, stopSignal chan struct{}) (*api.Message, error) { +func (a *AppModel) Prompt( + messages []api.Message, + chatReplyChunks chan provider.Chunk, + stopSignal chan struct{}, +) (*api.Message, error) { model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName) if err != nil { return nil, err @@ -274,7 +278,12 @@ func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provid } }() - return p.CreateChatCompletionStream( + msg, err := p.CreateChatCompletionStream( ctx, params, messages, chatReplyChunks, ) + if msg != nil { + msg.Metadata.GenerationProvider = &a.ProviderName + msg.Metadata.GenerationModel = &a.Model + } + return msg, err } diff --git a/pkg/tui/views/chat/cmds.go b/pkg/tui/views/chat/cmds.go index b4b681f..175703f 100644 --- a/pkg/tui/views/chat/cmds.go +++ b/pkg/tui/views/chat/cmds.go @@ -129,7 +129,7 @@ func (m *Model) promptLLM() tea.Cmd { m.tokenCount = 0 return func() tea.Msg { - resp, err := m.App.PromptLLM(m.App.Messages, m.chatReplyChunks, m.stopSignal) + resp, err := m.App.Prompt(m.App.Messages, m.chatReplyChunks, m.stopSignal) if err != nil { return msgChatResponseError{ Err: err } }