From 2b38db7db7ca8ebefe1e0f0517e7a058af32c6b5 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Tue, 7 May 2024 07:11:04 +0000
Subject: [PATCH] Update command flag handling

`lmcli chat` now supports common prompt flags (model, length, system prompt, etc)
---
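Notes (illustrative only, not part of the diff below): with these flags applied,
an invocation looks like, e.g.,
`lmcli chat -m <model> -t 0.7 --max-length 1024 <conversation>` (the model and
values are placeholders). Any future subcommand can opt into the same flags the
way the chat/new/prompt/reply/retry/continue commands do: build its
cobra.Command, then call applyPromptFlags. A minimal sketch, where SummarizeCmd
and the "summarize" command are hypothetical:

package cmd

import (
	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
	"github.com/spf13/cobra"
)

// SummarizeCmd is hypothetical and only illustrates the pattern:
// applyPromptFlags registers --model/-m, --max-length, --temperature/-t,
// --system-prompt and --system-prompt-file, binding them to
// ctx.Config.Defaults and ctx.SystemPromptFile.
func SummarizeCmd(ctx *lmcli.Context) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "summarize <conversation>",
		Short: "Hypothetical subcommand shown for illustration only",
		RunE: func(cmd *cobra.Command, args []string) error {
			// By the time RunE runs, the flag values have been written into
			// ctx, so the command body reads them from ctx as usual.
			return nil
		},
	}
	applyPromptFlags(ctx, cmd)
	return cmd
}
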
 pkg/cmd/chat.go                           |  1 +
 pkg/cmd/cmd.go                            | 81 +++++++++--------------
 pkg/cmd/continue.go                       |  3 +-
 pkg/cmd/new.go                            | 13 ++--
 pkg/cmd/prompt.go                         | 10 +--
 pkg/cmd/reply.go                          |  2 +
 pkg/cmd/retry.go                          |  2 +
 pkg/cmd/util/util.go                      |  8 +--
 pkg/lmcli/lmcli.go                        | 18 ++++-
 pkg/lmcli/model/conversation.go           |  6 +-
 pkg/lmcli/provider/anthropic/anthropic.go |  1 -
 11 files changed, 74 insertions(+), 71 deletions(-)

diff --git a/pkg/cmd/chat.go b/pkg/cmd/chat.go
index 0c91719..f86c98d 100644
--- a/pkg/cmd/chat.go
+++ b/pkg/cmd/chat.go
@@ -33,5 +33,6 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go
index 9a8fa59..593fff2 100644
--- a/pkg/cmd/cmd.go
+++ b/pkg/cmd/cmd.go
@@ -8,10 +8,6 @@ import (
 	"github.com/spf13/cobra"
 )
 
-var (
-	systemPromptFile string
-)
-
 func RootCmd(ctx *lmcli.Context) *cobra.Command {
 	var root = &cobra.Command{
 		Use: "lmcli [flags]",
@@ -23,58 +19,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	chatCmd := ChatCmd(ctx)
-	continueCmd := ContinueCmd(ctx)
-	cloneCmd := CloneCmd(ctx)
-	editCmd := EditCmd(ctx)
-	listCmd := ListCmd(ctx)
-	newCmd := NewCmd(ctx)
-	promptCmd := PromptCmd(ctx)
-	renameCmd := RenameCmd(ctx)
-	replyCmd := ReplyCmd(ctx)
-	retryCmd := RetryCmd(ctx)
-	rmCmd := RemoveCmd(ctx)
-	viewCmd := ViewCmd(ctx)
-
-	inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
-	for _, cmd := range inputCmds {
-		cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
-		cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
-			return ctx.GetModels(), cobra.ShellCompDirectiveDefault
-		})
-		cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
-		cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
-		cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
-		cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
-	}
-
 	root.AddCommand(
-		chatCmd,
-		cloneCmd,
-		continueCmd,
-		editCmd,
-		listCmd,
-		newCmd,
-		promptCmd,
-		renameCmd,
-		replyCmd,
-		retryCmd,
-		rmCmd,
-		viewCmd,
+		ChatCmd(ctx),
+		ContinueCmd(ctx),
+		CloneCmd(ctx),
+		EditCmd(ctx),
+		ListCmd(ctx),
+		NewCmd(ctx),
+		PromptCmd(ctx),
+		RenameCmd(ctx),
+		ReplyCmd(ctx),
+		RetryCmd(ctx),
+		RemoveCmd(ctx),
+		ViewCmd(ctx),
 	)
 
 	return root
 }
 
-func getSystemPrompt(ctx *lmcli.Context) string {
-	if systemPromptFile != "" {
-		content, err := util.ReadFileContents(systemPromptFile)
-		if err != nil {
-			lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
-		}
-		return content
-	}
-	return *ctx.Config.Defaults.SystemPrompt
+func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
+	f := cmd.Flags()
+
+	f.StringVarP(
+		ctx.Config.Defaults.Model,
+		"model", "m",
+		*ctx.Config.Defaults.Model,
+		"The model to generate a response with",
+	)
+	cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
+		return ctx.GetModels(), cobra.ShellCompDirectiveDefault
+	})
+
+	f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
+	f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
+
+	f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
+	f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
+	cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
 }
 
 // inputFromArgsOrEditor returns either the provided input from the args slice
diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go
index 6927db2..0869769 100644
--- a/pkg/cmd/continue.go
+++ b/pkg/cmd/continue.go
@@ -44,7 +44,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			fmt.Print(lastMessage.Content)
 
 			// Submit the LLM request, allowing it to continue the last message
-			continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("error fetching LLM response: %v", err)
 			}
@@ -68,5 +68,6 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go
index 6875681..0ef6a5d 100644
--- a/pkg/cmd/new.go
+++ b/pkg/cmd/new.go
@@ -28,14 +28,14 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 
 			messages := []model.Message{
 				{
-					ConversationID: conversation.ID,
-					Role: model.MessageRoleSystem,
-					Content: getSystemPrompt(ctx),
+					ConversationID: conversation.ID,
+					Role:           model.MessageRoleSystem,
+					Content:        ctx.GetSystemPrompt(),
 				},
 				{
-					ConversationID: conversation.ID,
-					Role: model.MessageRoleUser,
-					Content: messageContents,
+					ConversationID: conversation.ID,
+					Role:           model.MessageRoleUser,
+					Content:        messageContents,
 				},
 			}
 
@@ -56,5 +56,6 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go
index 7e30d47..8e9411b 100644
--- a/pkg/cmd/prompt.go
+++ b/pkg/cmd/prompt.go
@@ -22,21 +22,23 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 
 			messages := []model.Message{
 				{
-					Role: model.MessageRoleSystem,
-					Content: getSystemPrompt(ctx),
+					Role:    model.MessageRoleSystem,
+					Content: ctx.GetSystemPrompt(),
 				},
 				{
-					Role: model.MessageRoleUser,
+					Role:    model.MessageRoleUser,
 					Content: message,
 				},
 			}
 
-			_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
+			_, err := cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
 			return nil
 		},
 	}
+
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go
index d923aaa..25292cb 100644
--- a/pkg/cmd/reply.go
+++ b/pkg/cmd/reply.go
@@ -45,5 +45,7 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go
index 9604830..f0e2ba0 100644
--- a/pkg/cmd/retry.go
+++ b/pkg/cmd/retry.go
@@ -54,5 +54,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
+
+	applyPromptFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go
index 5bd7205..a21144c 100644
--- a/pkg/cmd/util/util.go
+++ b/pkg/cmd/util/util.go
@@ -13,9 +13,9 @@ import (
 	"github.com/charmbracelet/lipgloss"
 )
 
-// fetchAndShowCompletion prompts the LLM with the given messages and streams
-// the response to stdout. Returns all model reply messages.
-func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
+// Prompt prompts the configured model and streams the response
+// to stdout. Returns all model reply messages.
+func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
 	content := make(chan string) // receives the reponse from LLM
 	defer close(content)
 
@@ -109,7 +109,7 @@ func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist
 		}
 	}
 
-	_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
+	_, err = Prompt(ctx, allMessages, replyCallback)
 	if err != nil {
 		lmcli.Fatal("Error fetching LLM response: %v\n", err)
 	}
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index 0f814e6..6a5b4c0 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -10,17 +10,20 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
+	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"git.mlow.ca/mlow/lmcli/pkg/util/tty"
 
 	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
 )
 
 type Context struct {
-	Config       *Config
+	Config       *Config // may be updated at runtime
 	Store        ConversationStore
 	Chroma       *tty.ChromaHighlighter
 	EnabledTools []model.Tool
+
+	SystemPromptFile string
 }
 
 func NewContext() (*Context, error) {
@@ -50,7 +53,7 @@ func NewContext() (*Context, error) {
 		}
 	}
 
-	return &Context{config, store, chroma, enabledTools}, nil
+	return &Context{config, store, chroma, enabledTools, ""}, nil
 }
 
 func (c *Context) GetModels() (models []string) {
@@ -96,6 +99,17 @@ func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionCl
 	return nil, fmt.Errorf("unknown model: %s", model)
 }
 
+func (c *Context) GetSystemPrompt() string {
+	if c.SystemPromptFile != "" {
+		content, err := util.ReadFileContents(c.SystemPromptFile)
+		if err != nil {
+			Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
+		}
+		return content
+	}
+	return *c.Config.Defaults.SystemPrompt
+}
+
 func configDir() string {
 	var configDir string
 
diff --git a/pkg/lmcli/model/conversation.go b/pkg/lmcli/model/conversation.go
index d08bd3e..3aa1516 100644
--- a/pkg/lmcli/model/conversation.go
+++ b/pkg/lmcli/model/conversation.go
@@ -32,13 +32,13 @@ type Conversation struct {
 }
 
 type RequestParameters struct {
-	Model string
+	Model     string
+	MaxTokens int
 
 	Temperature float32
 	TopP        float32
 
-	SystemPrompt string
-	ToolBag      []Tool
+	ToolBag []Tool
 }
 
 func (m *MessageRole) IsAssistant() bool {
diff --git a/pkg/lmcli/provider/anthropic/anthropic.go b/pkg/lmcli/provider/anthropic/anthropic.go
index 8d951ca..706f449 100644
--- a/pkg/lmcli/provider/anthropic/anthropic.go
+++ b/pkg/lmcli/provider/anthropic/anthropic.go
@@ -19,7 +19,6 @@ func buildRequest(params model.RequestParameters, messages []model.Message) Requ
 	requestBody := Request{
 		Model:       params.Model,
 		Messages:    make([]Message, len(messages)),
-		System:      params.SystemPrompt,
 		MaxTokens:   params.MaxTokens,
 		Temperature: params.Temperature,
 		Stream:      false,