Update command flag handling

`lmcli chat` now supports common prompt flags (model, length, system
prompt, etc.).
Matt Low 2024-05-07 07:11:04 +00:00
parent 8e4ff90ab4
commit 2b38db7db7
11 changed files with 74 additions and 71 deletions

@@ -33,5 +33,6 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

@@ -8,10 +8,6 @@ import (
 	"github.com/spf13/cobra"
 )
 
-var (
-	systemPromptFile string
-)
-
 func RootCmd(ctx *lmcli.Context) *cobra.Command {
 	var root = &cobra.Command{
 		Use:   "lmcli <command> [flags]",
@@ -23,58 +19,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	chatCmd := ChatCmd(ctx)
-	continueCmd := ContinueCmd(ctx)
-	cloneCmd := CloneCmd(ctx)
-	editCmd := EditCmd(ctx)
-	listCmd := ListCmd(ctx)
-	newCmd := NewCmd(ctx)
-	promptCmd := PromptCmd(ctx)
-	renameCmd := RenameCmd(ctx)
-	replyCmd := ReplyCmd(ctx)
-	retryCmd := RetryCmd(ctx)
-	rmCmd := RemoveCmd(ctx)
-	viewCmd := ViewCmd(ctx)
-
-	inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
-	for _, cmd := range inputCmds {
-		cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
-		cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
-			return ctx.GetModels(), cobra.ShellCompDirectiveDefault
-		})
-		cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
-		cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
-		cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
-		cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
-	}
-
 	root.AddCommand(
-		chatCmd,
-		cloneCmd,
-		continueCmd,
-		editCmd,
-		listCmd,
-		newCmd,
-		promptCmd,
-		renameCmd,
-		replyCmd,
-		retryCmd,
-		rmCmd,
-		viewCmd,
+		ChatCmd(ctx),
+		ContinueCmd(ctx),
+		CloneCmd(ctx),
+		EditCmd(ctx),
+		ListCmd(ctx),
+		NewCmd(ctx),
+		PromptCmd(ctx),
+		RenameCmd(ctx),
+		ReplyCmd(ctx),
+		RetryCmd(ctx),
+		RemoveCmd(ctx),
+		ViewCmd(ctx),
 	)
 
 	return root
 }
 
-func getSystemPrompt(ctx *lmcli.Context) string {
-	if systemPromptFile != "" {
-		content, err := util.ReadFileContents(systemPromptFile)
-		if err != nil {
-			lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
-		}
-		return content
-	}
-	return *ctx.Config.Defaults.SystemPrompt
+func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
+	f := cmd.Flags()
+
+	f.StringVarP(
+		ctx.Config.Defaults.Model,
+		"model", "m",
+		*ctx.Config.Defaults.Model,
+		"The model to generate a response with",
+	)
+	cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
+		return ctx.GetModels(), cobra.ShellCompDirectiveDefault
+	})
+
+	f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
+	f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
+	f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
+	f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
+	cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
 }
 
 // inputFromArgsOrEditor returns either the provided input from the args slice

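The hunk above centralizes flag registration in applyPromptFlags, which binds each flag directly to the pointers held in ctx.Config.Defaults (and to ctx.SystemPromptFile). As a minimal sketch of how a command opts in, assuming the same package as the subcommands above; SummarizeCmd is a hypothetical example, not part of this commit:

// Hypothetical subcommand, shown only to illustrate the wiring.
func SummarizeCmd(ctx *lmcli.Context) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "summarize <conversation>",
		Short: "Summarize a conversation (hypothetical example)",
		RunE: func(cmd *cobra.Command, args []string) error {
			// By the time RunE runs, cobra has parsed the flags, and because
			// applyPromptFlags targets the config's own pointers,
			// *ctx.Config.Defaults.Model, *ctx.Config.Defaults.MaxTokens, etc.
			// already reflect any values passed on the command line.
			return nil
		},
	}
	applyPromptFlags(ctx, cmd)
	return cmd
}

Binding flags to the config pointers is also why the Config field gains a "may be updated at runtime" comment in the Context struct below: flag values overwrite the configured defaults for the duration of the invocation.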

@@ -44,7 +44,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)
 
 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
@@ -68,5 +68,6 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

@@ -28,14 +28,14 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 			messages := []model.Message{
 				{
 					ConversationID: conversation.ID,
 					Role:           model.MessageRoleSystem,
-					Content:        getSystemPrompt(ctx),
+					Content:        ctx.GetSystemPrompt(),
 				},
 				{
 					ConversationID: conversation.ID,
 					Role:           model.MessageRoleUser,
 					Content:        messageContents,
 				},
 			}
@@ -56,5 +56,6 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

@@ -22,21 +22,23 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 			messages := []model.Message{
 				{
 					Role:    model.MessageRoleSystem,
-					Content: getSystemPrompt(ctx),
+					Content: ctx.GetSystemPrompt(),
 				},
 				{
 					Role:    model.MessageRoleUser,
 					Content: message,
 				},
 			}
 
-			_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			_, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
 
 			return nil
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }

@@ -45,5 +45,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
+
 	return cmd
 }

@@ -54,5 +54,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
+
 	return cmd
 }

@@ -13,9 +13,9 @@ import (
 	"github.com/charmbracelet/lipgloss"
 )
 
-// fetchAndShowCompletion prompts the LLM with the given messages and streams
-// the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
+// Prompt prompts the configured model and streams the response
+// to stdout. Returns all model reply messages.
+func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
@@ -109,7 +109,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 		}
 	}
 
-	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	_, err = Prompt(ctx, allMessages, replyCallback)
 	if err != nil {
 		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}

@@ -10,17 +10,20 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
+	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"git.mlow.ca/mlow/lmcli/pkg/util/tty"
 	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
 )
 
 type Context struct {
-	Config       *Config
+	Config       *Config // may be updated at runtime
 	Store        ConversationStore
 	Chroma       *tty.ChromaHighlighter
 	EnabledTools []model.Tool
+	SystemPromptFile string
 }
 
 func NewContext() (*Context, error) {
@@ -50,7 +53,7 @@ func NewContext() (*Context, error) {
 		}
 	}
 
-	return &Context{config, store, chroma, enabledTools}, nil
+	return &Context{config, store, chroma, enabledTools, ""}, nil
 }
 
 func (c *Context) GetModels() (models []string) {
@@ -96,6 +99,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 	return nil, fmt.Errorf("unknown model: %s", model)
 }
 
+func (c *Context) GetSystemPrompt() string {
+	if c.SystemPromptFile != "" {
+		content, err := util.ReadFileContents(c.SystemPromptFile)
+		if err != nil {
+			Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
+		}
+		return content
+	}
+	return *c.Config.Defaults.SystemPrompt
+}
+
 func configDir() string {
 	var configDir string

@@ -32,13 +32,13 @@ type Conversation struct {
 }
 
 type RequestParameters struct {
 	Model       string
 	MaxTokens   int
 	Temperature float32
 	TopP        float32
-	SystemPrompt string
-	ToolBag      []Tool
+	ToolBag     []Tool
 }
 
 func (m *MessageRole) IsAssistant() bool {

@@ -19,7 +19,6 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 	requestBody := Request{
 		Model:       params.Model,
 		Messages:    make([]Message, len(messages)),
-		System:      params.SystemPrompt,
 		MaxTokens:   params.MaxTokens,
 		Temperature: params.Temperature,
 		Stream:      false,
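Together, the last two hunks remove the dedicated system-prompt plumbing: RequestParameters loses its SystemPrompt field, and the provider's buildRequest no longer copies it into the request body. The system prompt instead travels as a leading system message built from ctx.GetSystemPrompt(), as the earlier hunks show. A rough sketch of the resulting shape at a call site, purely illustrative; the params literal and userInput are assumptions, not taken from this diff:

// Hypothetical fragment inside a command's RunE.
params := model.RequestParameters{
	Model:       *ctx.Config.Defaults.Model,
	MaxTokens:   *ctx.Config.Defaults.MaxTokens,
	Temperature: *ctx.Config.Defaults.Temperature,
	ToolBag:     ctx.EnabledTools,
}
messages := []model.Message{
	{Role: model.MessageRoleSystem, Content: ctx.GetSystemPrompt()},
	{Role: model.MessageRoleUser, Content: userInput}, // userInput is hypothetical
}
// params and messages are then handed to the provider; nothing in the request
// parameters carries the system prompt anymore.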