Private
Public Access
1
0

Add validation to command line flags + update system prompt handling

Renamed `applyPromptFlags` to `applyGenerationFlags` and added
`validateGenerationFlags`
This commit is contained in:
2024-06-23 04:47:47 +00:00
parent 677cfcfebf
commit f89cc7b410
11 changed files with 90 additions and 44 deletions

View File

@@ -1,6 +1,8 @@
package cmd
import (
"fmt"
"slices"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
@@ -37,27 +39,43 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
return root
}
func applyPromptFlags(ctx *lmcli.Context, cmd *cobra.Command) {
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
f := cmd.Flags()
// -m, --model
f.StringVarP(
ctx.Config.Defaults.Model,
"model", "m",
*ctx.Config.Defaults.Model,
"The model to generate a response with",
ctx.Config.Defaults.Model, "model", "m",
*ctx.Config.Defaults.Model, "Which model to generate a response with",
)
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
})
// --max-length
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
// --temperature
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
// --system-prompt
f.StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
f.StringVar(&ctx.SystemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
// --system-prompt-file
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
f := cmd.Flags()
model, err := f.GetString("model")
if err != nil {
return fmt.Errorf("Error parsing --model: %w", err)
}
if !slices.Contains(ctx.GetModels(), model) {
return fmt.Errorf("Unknown model: %s", model)
}
return nil
}
// inputFromArgsOrEditor returns either the provided input from the args slice
// (joined with spaces), or if len(args) is 0, opens an editor and returns
// whatever input was provided there. placeholder is a string which populates