Update system prompt handling (again)

Add `api.ApplySystemPrompt`, renamed `GetSystemPrompt` to
`DefaultSystemPrompt`.
This commit is contained in:
Matt Low 2024-06-23 18:35:20 +00:00
parent ba7018af11
commit a43a91c6ff
7 changed files with 39 additions and 43 deletions

View File

@ -31,6 +31,20 @@ type Message struct {
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"` SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
} }
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
if len(m) > 0 && m[0].Role == MessageRoleSystem {
if force {
m[0].Content = system
}
return m
} else {
return append([]Message{{
Role: MessageRoleSystem,
Content: system,
}}, m...)
}
}
func (m *MessageRole) IsAssistant() bool { func (m *MessageRole) IsAssistant() bool {
switch *m { switch *m {
case MessageRoleAssistant, MessageRoleToolCall: case MessageRoleAssistant, MessageRoleToolCall:

View File

@ -25,20 +25,10 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []api.Message messages := []api.Message{{
system := ctx.Config.GetSystemPrompt()
if system != "" {
messages = append(messages, api.Message{
Role: api.MessageRoleSystem,
Content: system,
})
}
messages = append(messages, api.Message{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) }}
conversation, messages, err := ctx.Store.StartConversation(messages...) conversation, messages, err := ctx.Store.StartConversation(messages...)
if err != nil { if err != nil {

View File

@ -25,20 +25,10 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
return fmt.Errorf("No message was provided.") return fmt.Errorf("No message was provided.")
} }
var messages []api.Message messages := []api.Message{{
system := ctx.Config.GetSystemPrompt()
if system != "" {
messages = append(messages, api.Message{
Role: api.MessageRoleSystem,
Content: system,
})
}
messages = append(messages, api.Message{
Role: api.MessageRoleUser, Role: api.MessageRoleUser,
Content: input, Content: input,
}) }}
_, err = cmdutil.Prompt(ctx, messages, nil) _, err = cmdutil.Prompt(ctx, messages, nil)
if err != nil { if err != nil {

View File

@ -22,11 +22,15 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
return nil, err return nil, err
} }
requestParams := api.RequestParameters{ params := api.RequestParameters{
Model: m, Model: m,
MaxTokens: *ctx.Config.Defaults.MaxTokens, MaxTokens: *ctx.Config.Defaults.MaxTokens,
Temperature: *ctx.Config.Defaults.Temperature, Temperature: *ctx.Config.Defaults.Temperature,
ToolBag: ctx.EnabledTools, }
system := ctx.DefaultSystemPrompt()
if system != "" {
messages = api.ApplySystemPrompt(messages, system, false)
} }
content := make(chan api.Chunk) content := make(chan api.Chunk)
@ -36,7 +40,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
go ShowDelayedContent(content) go ShowDelayedContent(content)
reply, err := provider.CreateChatCompletionStream( reply, err := provider.CreateChatCompletionStream(
context.Background(), requestParams, messages, content, context.Background(), params, messages, content,
) )
if reply.Content != "" { if reply.Content != "" {

View File

@ -70,14 +70,3 @@ func NewConfig(configFile string) (*Config, error) {
return c, nil return c, nil
} }
func (c *Config) GetSystemPrompt() string {
if c.Defaults.SystemPromptFile != "" {
content, err := util.ReadFileContents(c.Defaults.SystemPromptFile)
if err != nil {
Fatal("Could not read file contents at %s: %v\n", c.Defaults.SystemPromptFile, err)
}
return content
}
return c.Defaults.SystemPrompt
}

View File

@ -12,6 +12,7 @@ import (
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google" "git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama" "git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai" "git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
"git.mlow.ca/mlow/lmcli/pkg/util"
"git.mlow.ca/mlow/lmcli/pkg/util/tty" "git.mlow.ca/mlow/lmcli/pkg/util/tty"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
@ -81,6 +82,17 @@ func (c *Context) GetModels() (models []string) {
return return
} }
func (c *Context) DefaultSystemPrompt() string {
if c.Config.Defaults.SystemPromptFile != "" {
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
if err != nil {
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
}
return content
}
return c.Config.Defaults.SystemPrompt
}
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) { func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
parts := strings.Split(model, "@") parts := strings.Split(model, "@")

View File

@ -143,12 +143,9 @@ func Chat(shared shared.Shared) Model {
m.replyCursor.SetChar(" ") m.replyCursor.SetChar(" ")
m.replyCursor.Focus() m.replyCursor.Focus()
system := shared.Ctx.Config.GetSystemPrompt() system := shared.Ctx.DefaultSystemPrompt()
if system != "" { if system != "" {
m.messages = []api.Message{{ m.messages = api.ApplySystemPrompt(m.messages, system, false)
Role: api.MessageRoleSystem,
Content: system,
}}
} }
m.input.Focus() m.input.Focus()