Update system prompt handling (again)
Add `api.ApplySystemPrompt`; rename `GetSystemPrompt` to `DefaultSystemPrompt`.
This commit is contained in:
parent
ba7018af11
commit
a43a91c6ff
@ -31,6 +31,20 @@ type Message struct {
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||
if force {
|
||||
m[0].Content = system
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
return append([]Message{{
|
||||
Role: MessageRoleSystem,
|
||||
Content: system,
|
||||
}}, m...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MessageRole) IsAssistant() bool {
|
||||
switch *m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
|
@ -25,20 +25,10 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
|
||||
system := ctx.Config.GetSystemPrompt()
|
||||
if system != "" {
|
||||
messages = append(messages, api.Message{
|
||||
Role: api.MessageRoleSystem,
|
||||
Content: system,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{
|
||||
messages := []api.Message{{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
}}
|
||||
|
||||
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
||||
if err != nil {
|
||||
|
@ -25,20 +25,10 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
|
||||
system := ctx.Config.GetSystemPrompt()
|
||||
if system != "" {
|
||||
messages = append(messages, api.Message{
|
||||
Role: api.MessageRoleSystem,
|
||||
Content: system,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, api.Message{
|
||||
messages := []api.Message{{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
}}
|
||||
|
||||
_, err = cmdutil.Prompt(ctx, messages, nil)
|
||||
if err != nil {
|
||||
|
@ -22,11 +22,15 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestParams := api.RequestParameters{
|
||||
params := api.RequestParameters{
|
||||
Model: m,
|
||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *ctx.Config.Defaults.Temperature,
|
||||
ToolBag: ctx.EnabledTools,
|
||||
}
|
||||
|
||||
system := ctx.DefaultSystemPrompt()
|
||||
if system != "" {
|
||||
messages = api.ApplySystemPrompt(messages, system, false)
|
||||
}
|
||||
|
||||
content := make(chan api.Chunk)
|
||||
@ -36,7 +40,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
||||
go ShowDelayedContent(content)
|
||||
|
||||
reply, err := provider.CreateChatCompletionStream(
|
||||
context.Background(), requestParams, messages, content,
|
||||
context.Background(), params, messages, content,
|
||||
)
|
||||
|
||||
if reply.Content != "" {
|
||||
|
@ -70,14 +70,3 @@ func NewConfig(configFile string) (*Config, error) {
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Config) GetSystemPrompt() string {
|
||||
if c.Defaults.SystemPromptFile != "" {
|
||||
content, err := util.ReadFileContents(c.Defaults.SystemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v\n", c.Defaults.SystemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return c.Defaults.SystemPrompt
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
@ -81,6 +82,17 @@ func (c *Context) GetModels() (models []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) DefaultSystemPrompt() string {
|
||||
if c.Config.Defaults.SystemPromptFile != "" {
|
||||
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return c.Config.Defaults.SystemPrompt
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
||||
parts := strings.Split(model, "@")
|
||||
|
||||
|
@ -143,12 +143,9 @@ func Chat(shared shared.Shared) Model {
|
||||
m.replyCursor.SetChar(" ")
|
||||
m.replyCursor.Focus()
|
||||
|
||||
system := shared.Ctx.Config.GetSystemPrompt()
|
||||
system := shared.Ctx.DefaultSystemPrompt()
|
||||
if system != "" {
|
||||
m.messages = []api.Message{{
|
||||
Role: api.MessageRoleSystem,
|
||||
Content: system,
|
||||
}}
|
||||
m.messages = api.ApplySystemPrompt(m.messages, system, false)
|
||||
}
|
||||
|
||||
m.input.Focus()
|
||||
|
Loading…
Reference in New Issue
Block a user