Update system prompt handling (again)
Add `api.ApplySystemPrompt`, renamed `GetSystemPrompt` to `DefaultSystemPrompt`.
This commit is contained in:
parent
ba7018af11
commit
a43a91c6ff
@ -31,6 +31,20 @@ type Message struct {
|
|||||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||||
|
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||||
|
if force {
|
||||||
|
m[0].Content = system
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
} else {
|
||||||
|
return append([]Message{{
|
||||||
|
Role: MessageRoleSystem,
|
||||||
|
Content: system,
|
||||||
|
}}, m...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MessageRole) IsAssistant() bool {
|
func (m *MessageRole) IsAssistant() bool {
|
||||||
switch *m {
|
switch *m {
|
||||||
case MessageRoleAssistant, MessageRoleToolCall:
|
case MessageRoleAssistant, MessageRoleToolCall:
|
||||||
|
@ -25,20 +25,10 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []api.Message
|
messages := []api.Message{{
|
||||||
|
|
||||||
system := ctx.Config.GetSystemPrompt()
|
|
||||||
if system != "" {
|
|
||||||
messages = append(messages, api.Message{
|
|
||||||
Role: api.MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = append(messages, api.Message{
|
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
}}
|
||||||
|
|
||||||
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -25,20 +25,10 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
var messages []api.Message
|
messages := []api.Message{{
|
||||||
|
|
||||||
system := ctx.Config.GetSystemPrompt()
|
|
||||||
if system != "" {
|
|
||||||
messages = append(messages, api.Message{
|
|
||||||
Role: api.MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
messages = append(messages, api.Message{
|
|
||||||
Role: api.MessageRoleUser,
|
Role: api.MessageRoleUser,
|
||||||
Content: input,
|
Content: input,
|
||||||
})
|
}}
|
||||||
|
|
||||||
_, err = cmdutil.Prompt(ctx, messages, nil)
|
_, err = cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -22,11 +22,15 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := api.RequestParameters{
|
params := api.RequestParameters{
|
||||||
Model: m,
|
Model: m,
|
||||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *ctx.Config.Defaults.Temperature,
|
Temperature: *ctx.Config.Defaults.Temperature,
|
||||||
ToolBag: ctx.EnabledTools,
|
}
|
||||||
|
|
||||||
|
system := ctx.DefaultSystemPrompt()
|
||||||
|
if system != "" {
|
||||||
|
messages = api.ApplySystemPrompt(messages, system, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make(chan api.Chunk)
|
content := make(chan api.Chunk)
|
||||||
@ -36,7 +40,7 @@ func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Messag
|
|||||||
go ShowDelayedContent(content)
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
reply, err := provider.CreateChatCompletionStream(
|
reply, err := provider.CreateChatCompletionStream(
|
||||||
context.Background(), requestParams, messages, content,
|
context.Background(), params, messages, content,
|
||||||
)
|
)
|
||||||
|
|
||||||
if reply.Content != "" {
|
if reply.Content != "" {
|
||||||
|
@ -70,14 +70,3 @@ func NewConfig(configFile string) (*Config, error) {
|
|||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) GetSystemPrompt() string {
|
|
||||||
if c.Defaults.SystemPromptFile != "" {
|
|
||||||
content, err := util.ReadFileContents(c.Defaults.SystemPromptFile)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("Could not read file contents at %s: %v\n", c.Defaults.SystemPromptFile, err)
|
|
||||||
}
|
|
||||||
return content
|
|
||||||
}
|
|
||||||
return c.Defaults.SystemPrompt
|
|
||||||
}
|
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -81,6 +82,17 @@ func (c *Context) GetModels() (models []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) DefaultSystemPrompt() string {
|
||||||
|
if c.Config.Defaults.SystemPromptFile != "" {
|
||||||
|
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return c.Config.Defaults.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
||||||
parts := strings.Split(model, "@")
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
|
@ -143,12 +143,9 @@ func Chat(shared shared.Shared) Model {
|
|||||||
m.replyCursor.SetChar(" ")
|
m.replyCursor.SetChar(" ")
|
||||||
m.replyCursor.Focus()
|
m.replyCursor.Focus()
|
||||||
|
|
||||||
system := shared.Ctx.Config.GetSystemPrompt()
|
system := shared.Ctx.DefaultSystemPrompt()
|
||||||
if system != "" {
|
if system != "" {
|
||||||
m.messages = []api.Message{{
|
m.messages = api.ApplySystemPrompt(m.messages, system, false)
|
||||||
Role: api.MessageRoleSystem,
|
|
||||||
Content: system,
|
|
||||||
}}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
m.input.Focus()
|
m.input.Focus()
|
||||||
|
Loading…
Reference in New Issue
Block a user