Update command flag handling
`lmcli chat` now supports common prompt flags (model, length, system prompt, etc)
This commit is contained in:
parent
8e4ff90ab4
commit
2b38db7db7
@ -33,5 +33,6 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -8,10 +8,6 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
systemPromptFile string
|
|
||||||
)
|
|
||||||
|
|
||||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
var root = &cobra.Command{
|
var root = &cobra.Command{
|
||||||
Use: "lmcli <command> [flags]",
|
Use: "lmcli <command> [flags]",
|
||||||
@ -23,58 +19,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
chatCmd := ChatCmd(ctx)
|
|
||||||
continueCmd := ContinueCmd(ctx)
|
|
||||||
cloneCmd := CloneCmd(ctx)
|
|
||||||
editCmd := EditCmd(ctx)
|
|
||||||
listCmd := ListCmd(ctx)
|
|
||||||
newCmd := NewCmd(ctx)
|
|
||||||
promptCmd := PromptCmd(ctx)
|
|
||||||
renameCmd := RenameCmd(ctx)
|
|
||||||
replyCmd := ReplyCmd(ctx)
|
|
||||||
retryCmd := RetryCmd(ctx)
|
|
||||||
rmCmd := RemoveCmd(ctx)
|
|
||||||
viewCmd := ViewCmd(ctx)
|
|
||||||
|
|
||||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
|
||||||
for _, cmd := range inputCmds {
|
|
||||||
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
|
||||||
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
|
||||||
})
|
|
||||||
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
|
||||||
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
|
||||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
|
||||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
|
||||||
}
|
|
||||||
|
|
||||||
root.AddCommand(
|
root.AddCommand(
|
||||||
chatCmd,
|
ChatCmd(ctx),
|
||||||
cloneCmd,
|
ContinueCmd(ctx),
|
||||||
continueCmd,
|
CloneCmd(ctx),
|
||||||
editCmd,
|
EditCmd(ctx),
|
||||||
listCmd,
|
ListCmd(ctx),
|
||||||
newCmd,
|
NewCmd(ctx),
|
||||||
promptCmd,
|
PromptCmd(ctx),
|
||||||
renameCmd,
|
RenameCmd(ctx),
|
||||||
replyCmd,
|
ReplyCmd(ctx),
|
||||||
retryCmd,
|
RetryCmd(ctx),
|
||||||
rmCmd,
|
RemoveCmd(ctx),
|
||||||
viewCmd,
|
ViewCmd(ctx),
|
||||||
)
|
)
|
||||||
|
|
||||||
return root
|
return root
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSystemPrompt(ctx *lmcli.Context) string {
|
func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
|
||||||
if systemPromptFile != "" {
|
f := cmd.Flags()
|
||||||
content, err := util.ReadFileContents(systemPromptFile)
|
|
||||||
if err != nil {
|
f.StringVarP(
|
||||||
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
ctx.Config.Defaults.Model,
|
||||||
}
|
"model", "m",
|
||||||
return content
|
*ctx.Config.Defaults.Model,
|
||||||
}
|
"The model to generate a response with",
|
||||||
return *ctx.Config.Defaults.SystemPrompt
|
)
|
||||||
|
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
||||||
|
})
|
||||||
|
|
||||||
|
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||||
|
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
|
||||||
|
|
||||||
|
f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||||
|
f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||||
|
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||||
}
|
}
|
||||||
|
|
||||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
|
@ -44,7 +44,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
fmt.Print(lastMessage.Content)
|
fmt.Print(lastMessage.Content)
|
||||||
|
|
||||||
// Submit the LLM request, allowing it to continue the last message
|
// Submit the LLM request, allowing it to continue the last message
|
||||||
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
@ -68,5 +68,6 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -30,7 +30,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
{
|
{
|
||||||
ConversationID: conversation.ID,
|
ConversationID: conversation.ID,
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: getSystemPrompt(ctx),
|
Content: ctx.GetSystemPrompt(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ConversationID: conversation.ID,
|
ConversationID: conversation.ID,
|
||||||
@ -56,5 +56,6 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
messages := []model.Message{
|
messages := []model.Message{
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleSystem,
|
Role: model.MessageRoleSystem,
|
||||||
Content: getSystemPrompt(ctx),
|
Content: ctx.GetSystemPrompt(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleUser,
|
Role: model.MessageRoleUser,
|
||||||
@ -31,12 +31,14 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
_, err := cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -45,5 +45,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -54,5 +54,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyPromptFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,9 @@ import (
|
|||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
)
|
)
|
||||||
|
|
||||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// the response to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
||||||
content := make(chan string) // receives the reponse from LLM
|
content := make(chan string) // receives the reponse from LLM
|
||||||
defer close(content)
|
defer close(content)
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
_, err = Prompt(ctx, allMessages, replyCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -10,17 +10,20 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Context struct {
|
type Context struct {
|
||||||
Config *Config
|
Config *Config // may be updated at runtime
|
||||||
Store ConversationStore
|
Store ConversationStore
|
||||||
|
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma *tty.ChromaHighlighter
|
||||||
EnabledTools []model.Tool
|
EnabledTools []model.Tool
|
||||||
|
|
||||||
|
SystemPromptFile string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContext() (*Context, error) {
|
func NewContext() (*Context, error) {
|
||||||
@ -50,7 +53,7 @@ func NewContext() (*Context, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Context{config, store, chroma, enabledTools}, nil
|
return &Context{config, store, chroma, enabledTools, ""}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
@ -96,6 +99,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
|
|||||||
return nil, fmt.Errorf("unknown model: %s", model)
|
return nil, fmt.Errorf("unknown model: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetSystemPrompt() string {
|
||||||
|
if c.SystemPromptFile != "" {
|
||||||
|
content, err := util.ReadFileContents(c.SystemPromptFile)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return *c.Config.Defaults.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
func configDir() string {
|
func configDir() string {
|
||||||
var configDir string
|
var configDir string
|
||||||
|
|
||||||
|
@ -33,11 +33,11 @@ type Conversation struct {
|
|||||||
|
|
||||||
type RequestParameters struct {
|
type RequestParameters struct {
|
||||||
Model string
|
Model string
|
||||||
|
|
||||||
MaxTokens int
|
MaxTokens int
|
||||||
Temperature float32
|
Temperature float32
|
||||||
TopP float32
|
TopP float32
|
||||||
|
|
||||||
SystemPrompt string
|
|
||||||
ToolBag []Tool
|
ToolBag []Tool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
|
|||||||
requestBody := Request{
|
requestBody := Request{
|
||||||
Model: params.Model,
|
Model: params.Model,
|
||||||
Messages: make([]Message, len(messages)),
|
Messages: make([]Message, len(messages)),
|
||||||
System: params.SystemPrompt,
|
|
||||||
MaxTokens: params.MaxTokens,
|
MaxTokens: params.MaxTokens,
|
||||||
Temperature: params.Temperature,
|
Temperature: params.Temperature,
|
||||||
Stream: false,
|
Stream: false,
|
||||||
|
Loading…
Reference in New Issue
Block a user