diff --git a/pkg/cmd/chat.go b/pkg/cmd/chat.go
index 527529d..be93369 100644
--- a/pkg/cmd/chat.go
+++ b/pkg/cmd/chat.go
@@ -15,6 +15,10 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 		Short: "Open the chat interface",
 		Long:  `Open the chat interface, optionally on a given conversation.`,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
 			shortname := ""
 			if len(args) == 1 {
 				shortname = args[0]
@@ -25,7 +29,7 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 					return err
 				}
 			}
-			err := tui.Launch(ctx, shortname)
+			err = tui.Launch(ctx, shortname)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
@@ -39,6 +43,6 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go
index 593fff2..764cea0 100644
--- a/pkg/cmd/cmd.go
+++ b/pkg/cmd/cmd.go
@@ -1,6 +1,8 @@
 package cmd
 
 import (
+	"fmt"
+	"slices"
 	"strings"
 
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
@@ -37,27 +39,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
 	return root
 }
 
-func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
+func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
 	f := cmd.Flags()
+	// -m, --model
 	f.StringVarP(
-		ctx.Config.Defaults.Model,
-		"model", "m",
-		*ctx.Config.Defaults.Model,
-		"The model to generate a response with",
+		ctx.Config.Defaults.Model, "model", "m",
+		*ctx.Config.Defaults.Model, "Which model to generate a response with",
 	)
 	cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
 		return ctx.GetModels(), cobra.ShellCompDirectiveDefault
 	})
+	// --max-length
 	f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
+	// --temperature
 	f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
+	// --system-prompt
 	f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
-	f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
+	// --system-prompt-file
+	f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
 	cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
 }
 
+func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
+	f := cmd.Flags()
+
+	model, err := f.GetString("model")
+	if err != nil {
+		return fmt.Errorf("Error parsing --model: %w", err)
+	}
+	if !slices.Contains(ctx.GetModels(), model) {
+		return fmt.Errorf("Unknown model: %s", model)
+	}
+	return nil
+}
+
 // inputFromArgsOrEditor returns either the provided input from the args slice
 // (joined with spaces), or if len(args) is 0, opens an editor and returns
 // whatever input was provided there. placeholder is a string which populates
 // the editor.
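
Note: the convention introduced in cmd.go above is followed by every generation command below: applyGenerationFlags registers the shared flags when the command is constructed, and validateGenerationFlags runs at the top of RunE so bad flag values are rejected before any work happens. A minimal sketch of a hypothetical new subcommand in pkg/cmd wired the same way (EditCmd and its body are illustrative only, not part of this change):

    func EditCmd(ctx *lmcli.Context) *cobra.Command {
        cmd := &cobra.Command{
            Use:   "edit",
            Short: "Example command using the shared generation flags",
            RunE: func(cmd *cobra.Command, args []string) error {
                // Reject e.g. an unknown --model before doing any work
                if err := validateGenerationFlags(ctx, cmd); err != nil {
                    return err
                }
                // ... command body ...
                return nil
            },
        }
        applyGenerationFlags(ctx, cmd)
        return cmd
    }
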
diff --git a/pkg/cmd/continue.go b/pkg/cmd/continue.go
index 73503d0..965efd1 100644
--- a/pkg/cmd/continue.go
+++ b/pkg/cmd/continue.go
@@ -23,6 +23,11 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			return nil
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
+
 			shortName := args[0]
 
 			conversation := cmdutil.LookupConversation(ctx, shortName)
@@ -68,6 +73,6 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
 			return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
 		},
 	}
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/new.go b/pkg/cmd/new.go
index aca7a9e..45a2002 100644
--- a/pkg/cmd/new.go
+++ b/pkg/cmd/new.go
@@ -15,6 +15,11 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 		Short: "Start a new conversation",
 		Long:  `Start a new conversation with the Large Language Model.`,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
+
 			input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
 			if input == "" {
 				return fmt.Errorf("No message was provided.")
@@ -22,8 +27,7 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 
 			var messages []api.Message
 
-			// TODO: probably just make this part of the conversation
-			system := ctx.GetSystemPrompt()
+			system := ctx.Config.GetSystemPrompt()
 			if system != "" {
 				messages = append(messages, api.Message{
 					Role:    api.MessageRoleSystem,
@@ -57,6 +61,6 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/prompt.go b/pkg/cmd/prompt.go
index 8664c91..9d17959 100644
--- a/pkg/cmd/prompt.go
+++ b/pkg/cmd/prompt.go
@@ -15,6 +15,11 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 		Short: "Do a one-shot prompt",
 		Long:  `Prompt the Large Language Model and get a response.`,
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
+
 			input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
 			if input == "" {
 				return fmt.Errorf("No message was provided.")
@@ -22,8 +27,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 
 			var messages []api.Message
 
-			// TODO: stop supplying system prompt as a message
-			system := ctx.GetSystemPrompt()
+			system := ctx.Config.GetSystemPrompt()
 			if system != "" {
 				messages = append(messages, api.Message{
 					Role:    api.MessageRoleSystem,
@@ -36,7 +40,7 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 				Content: input,
 			})
 
-			_, err := cmdutil.Prompt(ctx, messages, nil)
+			_, err = cmdutil.Prompt(ctx, messages, nil)
 			if err != nil {
 				return fmt.Errorf("Error fetching LLM response: %v", err)
 			}
@@ -44,6 +48,6 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/cmd/reply.go b/pkg/cmd/reply.go
index 6338566..a0c0a65 100644
--- a/pkg/cmd/reply.go
+++ b/pkg/cmd/reply.go
@@ -22,6 +22,11 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
 			return nil
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
+
 			shortName := args[0]
 
 			conversation := cmdutil.LookupConversation(ctx, shortName)
@@ -45,6 +50,6 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
 		},
 	}
 
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
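
Note: with validation in place, each of these commands now fails fast on an unknown model instead of erroring partway through a request. A hypothetical session (the model name and exact output are illustrative; cobra prefixes errors returned from RunE with "Error:" and may also print usage):

    $ lmcli prompt -m not-a-model 'say hello'
    Error: Unknown model: not-a-model
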
diff --git a/pkg/cmd/retry.go b/pkg/cmd/retry.go
index d88dd87..e2ba866 100644
--- a/pkg/cmd/retry.go
+++ b/pkg/cmd/retry.go
@@ -22,6 +22,11 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 			return nil
 		},
 		RunE: func(cmd *cobra.Command, args []string) error {
+			err := validateGenerationFlags(ctx, cmd)
+			if err != nil {
+				return err
+			}
+
 			shortName := args[0]
 
 			conversation := cmdutil.LookupConversation(ctx, shortName)
@@ -68,6 +73,6 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
 	cmd.Flags().Int("offset", 0, "Offset from the last message to retry from.")
 
-	applyPromptFlags(ctx, cmd)
+	applyGenerationFlags(ctx, cmd)
 	return cmd
 }
diff --git a/pkg/lmcli/config.go b/pkg/lmcli/config.go
index d266379..89190e1 100644
--- a/pkg/lmcli/config.go
+++ b/pkg/lmcli/config.go
@@ -10,10 +10,11 @@ import (
 
 type Config struct {
 	Defaults *struct {
-		SystemPrompt *string  `yaml:"systemPrompt" default:"You are a helpful assistant."`
-		MaxTokens    *int     `yaml:"maxTokens" default:"256"`
-		Temperature  *float32 `yaml:"temperature" default:"0.7"`
-		Model        *string  `yaml:"model" default:"gpt-4"`
+		SystemPromptFile string   `yaml:"systemPromptFile,omitempty"`
+		SystemPrompt     *string  `yaml:"systemPrompt" default:"You are a helpful assistant."`
+		MaxTokens        *int     `yaml:"maxTokens" default:"256"`
+		Temperature      *float32 `yaml:"temperature" default:"0.2"`
+		Model            *string  `yaml:"model" default:"gpt-4"`
 	} `yaml:"defaults"`
 	Conversations *struct {
 		TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
@@ -22,10 +23,10 @@ type Config struct {
 		EnabledTools []string `yaml:"enabledTools"`
 	} `yaml:"tools"`
 	Providers []*struct {
-		Name    *string   `yaml:"name"`
+		Name    *string   `yaml:"name,omitempty"`
 		Kind    *string   `yaml:"kind"`
-		BaseURL *string   `yaml:"baseUrl"`
-		APIKey  *string   `yaml:"apiKey"`
+		BaseURL *string   `yaml:"baseUrl,omitempty"`
+		APIKey  *string   `yaml:"apiKey,omitempty"`
 		Models  *[]string `yaml:"models"`
 	} `yaml:"providers"`
 	Chroma *struct {
@@ -68,3 +69,17 @@ func NewConfig(configFile string) (*Config, error) {
 
 	return c, nil
 }
+
+func (c *Config) GetSystemPrompt() string {
+	if c.Defaults.SystemPromptFile != "" {
+		content, err := util.ReadFileContents(c.Defaults.SystemPromptFile)
+		if err != nil {
+			Fatal("Could not read file contents at %s: %v\n", c.Defaults.SystemPromptFile, err)
+		}
+		return content
+	}
+	if c.Defaults.SystemPrompt == nil {
+		return ""
+	}
+	return *c.Defaults.SystemPrompt
+}
diff --git a/pkg/lmcli/lmcli.go b/pkg/lmcli/lmcli.go
index 651088f..3ee8c3e 100644
--- a/pkg/lmcli/lmcli.go
+++ b/pkg/lmcli/lmcli.go
@@ -12,7 +12,6 @@ import (
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
 	"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
-	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"git.mlow.ca/mlow/lmcli/pkg/util/tty"
 	"gorm.io/driver/sqlite"
 	"gorm.io/gorm"
@@ -24,8 +23,6 @@ type Context struct {
 	Chroma       *tty.ChromaHighlighter
 	EnabledTools []api.ToolSpec
-
-	SystemPromptFile string
 }
 
 func NewContext() (*Context, error) {
@@ -57,7 +54,7 @@ func NewContext() (*Context, error) {
 		}
 	}
 
-	return &Context{config, store, chroma, enabledTools, ""}, nil
+	return &Context{config, store, chroma, enabledTools}, nil
 }
 
 func (c *Context) GetModels() (models []string) {
@@ -139,17 +136,6 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
 	return "", nil, fmt.Errorf("unknown model: %s", model)
 }
 
-func (c *Context) GetSystemPrompt() string {
-	if c.SystemPromptFile != "" {
-		content, err := util.ReadFileContents(c.SystemPromptFile)
-		if err != nil {
-			Fatal("Could not read file contents at %s: %v\n", c.SystemPromptFile, err)
-		}
-		return content
-	}
-	return *c.Config.Defaults.SystemPrompt
-}
-
 func configDir() string {
 	var configDir string
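
Note: moving SystemPromptFile into Config.Defaults means the file path can now be set persistently in the config file rather than only via --system-prompt-file. A sketch of the relevant config.yaml section under the new schema (the path is illustrative; the keys follow the yaml struct tags above, and GetSystemPrompt gives the file precedence over the inline systemPrompt when both are set):

    defaults:
      systemPromptFile: /home/user/.config/lmcli/system-prompt.txt
      maxTokens: 256
      temperature: 0.2
      model: gpt-4
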
diff --git a/pkg/tui/views/chat/chat.go b/pkg/tui/views/chat/chat.go
index cdd5223..9984448 100644
--- a/pkg/tui/views/chat/chat.go
+++ b/pkg/tui/views/chat/chat.go
@@ -143,7 +143,7 @@ func Chat(shared shared.Shared) Model {
 	m.replyCursor.SetChar(" ")
 	m.replyCursor.Focus()
 
-	system := shared.Ctx.GetSystemPrompt()
+	system := shared.Ctx.Config.GetSystemPrompt()
 	if system != "" {
 		m.messages = []api.Message{{
 			Role:    api.MessageRoleSystem,
diff --git a/pkg/util/util.go b/pkg/util/util.go
index 8673aef..91b4feb 100644
--- a/pkg/util/util.go
+++ b/pkg/util/util.go
@@ -21,7 +21,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string, error) {
 	msgFile, _ := os.CreateTemp("/tmp", pattern)
 	defer os.Remove(msgFile.Name())
 
-	os.WriteFile(msgFile.Name(), []byte(placeholder + content), os.ModeAppend)
+	os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
 
 	editor := os.Getenv("EDITOR")
 	if editor == "" {
@@ -137,8 +137,8 @@ func SetStructDefaults(data interface{}) bool {
 		}
 
 		// Get the "default" struct tag
-		defaultTag := v.Type().Field(i).Tag.Get("default")
-		if defaultTag == "" {
+		defaultTag, ok := v.Type().Field(i).Tag.Lookup("default")
+		if !ok {
 			continue
 		}
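
Note: the Tag.Get to Tag.Lookup change in SetStructDefaults is behavioral, not cosmetic: Get returns an empty string both when the tag is absent and when it is present but empty, while Lookup reports presence separately, so a field tagged default:"" can now receive an explicit empty-string default instead of being skipped. A self-contained sketch of the distinction using the standard reflect package:

    package main

    import (
        "fmt"
        "reflect"
    )

    type example struct {
        A string                   // no default tag
        B string `default:""`      // tag present but empty
        C string `default:"gpt-4"` // tag present with a value
    }

    func main() {
        t := reflect.TypeOf(example{})
        for i := 0; i < t.NumField(); i++ {
            tag := t.Field(i).Tag
            val, ok := tag.Lookup("default") // ok is false only for A
            fmt.Printf("%s: Get=%q Lookup=(%q, %v)\n",
                t.Field(i).Name, tag.Get("default"), val, ok)
        }
    }
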