diff --git a/pkg/api/message.go b/pkg/api/message.go index cc7cec1..e51977d 100644 --- a/pkg/api/message.go +++ b/pkg/api/message.go @@ -31,6 +31,20 @@ type Message struct { SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` } +func ApplySystemPrompt(m []Message, system string, force bool) []Message { + if len(m) > 0 && m[0].Role == MessageRoleSystem { + if force { + m[0].Content = system + } + return m + } else { + return append([]Message{{ + Role: MessageRoleSystem, + Content: system, + }}, m...) + } +} + func (m *MessageRole) IsAssistant() bool { switch *m { case MessageRoleAssistant, MessageRoleToolCall: diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go index 45a2002..e4a61fc 100644 --- a/pkg/cmd/new.go +++ b/pkg/cmd/new.go @@ -25,20 +25,10 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - var messages []api.Message - - system := ctx.Config.GetSystemPrompt() - if system != "" { - messages = append(messages, api.Message{ - Role: api.MessageRoleSystem, - Content: system, - }) - } - - messages = append(messages, api.Message{ + messages := []api.Message{{ Role: api.MessageRoleUser, Content: input, - }) + }} conversation, messages, err := ctx.Store.StartConversation(messages...) if err != nil { diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go index 9d17959..abab6b9 100644 --- a/pkg/cmd/prompt.go +++ b/pkg/cmd/prompt.go @@ -25,20 +25,10 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command { return fmt.Errorf("No message was provided.") } - var messages []api.Message - - system := ctx.Config.GetSystemPrompt() - if system != "" { - messages = append(messages, api.Message{ - Role: api.MessageRoleSystem, - Content: system, - }) - } - - messages = append(messages, api.Message{ + messages := []api.Message{{ Role: api.MessageRoleUser, Content: input, - }) + }} _, err = cmdutil.Prompt(ctx, messages, nil) if err != nil { diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index c4407fa..ba93cd5 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -22,11 +22,15 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag return nil, err } - requestParams := api.RequestParameters{ + params := api.RequestParameters{ Model: m, MaxTokens: *ctx.Config.Defaults.MaxTokens, Temperature: *ctx.Config.Defaults.Temperature, - ToolBag: ctx.EnabledTools, + } + + system := ctx.DefaultSystemPrompt() + if system != "" { + messages = api.ApplySystemPrompt(messages, system, false) } content := make(chan api.Chunk) @@ -36,7 +40,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag go ShowDelayedContent(content) reply, err := provider.CreateChatCompletionStream( - context.Background(), requestParams, messages, content, + context.Background(), params, messages, content, ) if reply.Content != "" { diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go index ad25250..5c87114 100644 --- a/pkg/lmcli/config.go +++ b/pkg/lmcli/config.go @@ -70,14 +70,3 @@ func NewConfig(configFile string) (*Config, error) { return c, nil } - -func (c *Config) GetSystemPrompt() string { - if c.Defaults.SystemPromptFile != "" { - content, err := util.ReadFileContents(c.Defaults.SystemPromptFile) - if err != nil { - Fatal("Could not read file contents at %s: %v\n", c.Defaults.SystemPromptFile, err) - } - return content - } - return c.Defaults.SystemPrompt -} diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go index 6de906a..b48859d 100644 --- a/pkg/lmcli/lmcli.go +++ b/pkg/lmcli/lmcli.go @@ -12,6 +12,7 @@ import ( "git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" + "git.mlow.ca/mlow/lmcli/pkg/util" "git.mlow.ca/mlow/lmcli/pkg/util/tty" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -81,6 +82,17 @@ func (c *Context) GetModels() (models []string) { return } +func (c *Context) DefaultSystemPrompt() string { + if c.Config.Defaults.SystemPromptFile != "" { + content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile) + if err != nil { + Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err) + } + return content + } + return c.Config.Defaults.SystemPrompt +} + func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { parts := strings.Split(model, "@") diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go index 9984448..c07f4f2 100644 --- a/pkg/tui/views/chat/chat.go +++ b/pkg/tui/views/chat/chat.go @@ -143,12 +143,9 @@ func Chat(shared shared.Shared) Model { m.replyCursor.SetChar(" ") m.replyCursor.Focus() - system := shared.Ctx.Config.GetSystemPrompt() + system := shared.Ctx.DefaultSystemPrompt() if system != "" { - m.messages = []api.Message{{ - Role: api.MessageRoleSystem, - Content: system, - }} + m.messages = api.ApplySystemPrompt(m.messages, system, false) } m.input.Focus()