Add metadata json field to Message, store generation model/provider

Matt Low 2024-09-30 17:37:50 +00:00
parent 327a128b2f
commit 5d13c3e056
4 changed files with 44 additions and 21 deletions

View File

@@ -2,6 +2,8 @@ package api
 
 import (
 	"database/sql"
+	"database/sql/driver"
+	"encoding/json"
 	"time"
 )
@@ -23,21 +25,36 @@ const (
 	MessageRoleToolResult MessageRole = "tool_result"
 )
 
+type MessageMeta struct {
+	GenerationProvider *string `json:"generation_provider,omitempty"`
+	GenerationModel    *string `json:"generation_model,omitempty"`
+}
+
 type Message struct {
 	ID             uint          `gorm:"primaryKey"`
+	CreatedAt      time.Time
+	Metadata       MessageMeta
 	ConversationID *uint         `gorm:"index"`
 	Conversation   *Conversation `gorm:"foreignKey:ConversationID"`
-	CreatedAt   time.Time
+	Content     string
+	Role        MessageRole
+	ToolCalls   ToolCalls   // a json array of tool calls (from the model)
+	ToolResults ToolResults // a json array of tool results
 	ParentID        *uint
 	Parent          *Message  `gorm:"foreignKey:ParentID"`
 	Replies         []Message `gorm:"foreignKey:ParentID"`
 	SelectedReplyID *uint
 	SelectedReply   *Message  `gorm:"foreignKey:SelectedReplyID"`
-	Role        MessageRole
-	Content     string
-	ToolCalls   ToolCalls   // a json array of tool calls (from the model)
-	ToolResults ToolResults // a json array of tool results
 }
 
+func (m *MessageMeta) Scan(value interface{}) error {
+	return json.Unmarshal(value.([]byte), m)
+}
+
+func (m MessageMeta) Value() (driver.Value, error) {
+	return json.Marshal(m)
+}
+
 func ApplySystemPrompt(m []Message, system string, force bool) []Message {
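
A minimal, self-contained sketch of the round trip this Scan/Value pair enables (the provider/model values below are hypothetical, not from this commit): Value serializes MessageMeta to JSON on write and Scan restores it on read, which is what lets GORM persist the struct in a single column without a migration-time custom type.

package main

import (
	"database/sql/driver"
	"encoding/json"
	"fmt"
)

// MessageMeta mirrors the type added above.
type MessageMeta struct {
	GenerationProvider *string `json:"generation_provider,omitempty"`
	GenerationModel    *string `json:"generation_model,omitempty"`
}

// Scan implements sql.Scanner: decode raw column bytes into the struct.
func (m *MessageMeta) Scan(value interface{}) error {
	return json.Unmarshal(value.([]byte), m)
}

// Value implements driver.Valuer: encode the struct as JSON for storage.
func (m MessageMeta) Value() (driver.Value, error) {
	return json.Marshal(m)
}

func main() {
	provider, model := "openai", "gpt-4o" // hypothetical values
	meta := MessageMeta{GenerationProvider: &provider, GenerationModel: &model}

	// Write path: the ORM calls Value() and stores the returned bytes.
	stored, _ := meta.Value()
	fmt.Printf("stored: %s\n", stored)

	// Read path: the ORM calls Scan() with the column's raw bytes.
	var loaded MessageMeta
	_ = loaded.Scan(stored)
	fmt.Println(*loaded.GenerationProvider, *loaded.GenerationModel)
}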

View File

@@ -6,8 +6,6 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/api"
 )
 
-type ReplyCallback func(api.Message)
-
 type Chunk struct {
 	Content    string
 	TokenCount uint
@@ -24,9 +22,8 @@ type RequestParameters struct {
 }
 
 type ChatCompletionProvider interface {
-	// CreateChatCompletion requests a response to the provided messages.
-	// Replies are appended to the given replies struct, and the
-	// complete user-facing response is returned as a string.
+	// CreateChatCompletion generates a chat completion response to the
+	// provided messages.
 	CreateChatCompletion(
 		ctx context.Context,
 		params RequestParameters,
@@ -34,7 +31,7 @@ type ChatCompletionProvider interface {
 	) (*api.Message, error)
 
 	// Like CreateChatCompletion, except the response is streamed via
-	// the output channel as it's received.
+	// the output channel.
 	CreateChatCompletionStream(
 		ctx context.Context,
 		params RequestParameters,

View File

@@ -248,7 +248,11 @@ func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult,
 	return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
 }
 
-func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provider.Chunk, stopSignal chan struct{}) (*api.Message, error) {
+func (a *AppModel) Prompt(
+	messages []api.Message,
+	chatReplyChunks chan provider.Chunk,
+	stopSignal chan struct{},
+) (*api.Message, error) {
 	model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
 	if err != nil {
 		return nil, err
@@ -274,7 +278,12 @@ func (a *AppModel) PromptLLM(messages []api.Message, chatReplyChunks chan provid
 		}
 	}()
 
-	return p.CreateChatCompletionStream(
+	msg, err := p.CreateChatCompletionStream(
 		ctx, params, messages, chatReplyChunks,
 	)
+	if msg != nil {
+		msg.Metadata.GenerationProvider = &a.ProviderName
+		msg.Metadata.GenerationModel = &a.Model
+	}
+	return msg, err
 }
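
Note that the stream call above returns whatever partial message it produced alongside any error, and the nil check tags metadata even in that case, so an interrupted reply still records which provider and model generated it. A standalone sketch of that pattern (all names hypothetical, not from this repository):

package main

import (
	"errors"
	"fmt"
)

type MessageMeta struct {
	GenerationProvider *string
	GenerationModel    *string
}

type Message struct {
	Content  string
	Metadata MessageMeta
}

// streamOnce stands in for CreateChatCompletionStream: an interrupted
// stream can return a partial message together with an error.
func streamOnce() (*Message, error) {
	return &Message{Content: "partial reply..."}, errors.New("stream interrupted")
}

func main() {
	providerName, model := "openai", "gpt-4o" // hypothetical values

	msg, err := streamOnce()
	if msg != nil { // tag even a partial reply, mirroring the change above
		msg.Metadata.GenerationProvider = &providerName
		msg.Metadata.GenerationModel = &model
	}
	fmt.Printf("err=%v model=%s\n", err, *msg.Metadata.GenerationModel)
}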

View File

@@ -129,7 +129,7 @@ func (m *Model) promptLLM() tea.Cmd {
 	m.tokenCount = 0
 
 	return func() tea.Msg {
-		resp, err := m.App.PromptLLM(m.App.Messages, m.chatReplyChunks, m.stopSignal)
+		resp, err := m.App.Prompt(m.App.Messages, m.chatReplyChunks, m.stopSignal)
 		if err != nil {
 			return msgChatResponseError{ Err: err }
 		}