lmcli/pkg/cmd/cmd.go

109 lines
3.3 KiB
Go
Raw Normal View History

package cmd
import (
"fmt"
"slices"
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/spf13/cobra"
)
// RootCmd builds the top-level `lmcli` command and attaches every
// subcommand, each constructed with the shared ctx. Errors and usage output
// on error are silenced at this level; invoking the root command without a
// subcommand prints usage.
func RootCmd(ctx *lmcli.Context) *cobra.Command {
	root := &cobra.Command{
		Use:           "lmcli <command> [flags]",
		Long:          `lmcli - Large Language Model CLI`,
		SilenceErrors: true,
		SilenceUsage:  true,
		Run: func(c *cobra.Command, _ []string) {
			c.Usage()
		},
	}

	// Subcommands are constructed in this order and registered together.
	subcommands := []*cobra.Command{
		ChatCmd(ctx),
		ContinueCmd(ctx),
		CloneCmd(ctx),
		EditCmd(ctx),
		ListCmd(ctx),
		NewCmd(ctx),
		PromptCmd(ctx),
		RenameCmd(ctx),
		ReplyCmd(ctx),
		RetryCmd(ctx),
		RemoveCmd(ctx),
		ViewCmd(ctx),
	}
	root.AddCommand(subcommands...)

	return root
}
// applyGenerationFlags registers the flags shared by commands that generate a
// model response: --model/-m, --agent/-a, --max-length, --temperature/-t,
// --system-prompt and --system-prompt-file.
//
// Each flag is bound directly to the matching field of ctx.Config.Defaults,
// with that field's current value as the flag default. Parsing the command
// line therefore overwrites the configured default in place. --system-prompt
// and --system-prompt-file are mutually exclusive.
//
// NOTE(review): Model, MaxTokens and Temperature are dereferenced here, so
// they are assumed to be non-nil by the time this runs. TODO confirm that
// config loading guarantees this.
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
	f := cmd.Flags()

	// -m, --model
	f.StringVarP(
		ctx.Config.Defaults.Model, "model", "m",
		*ctx.Config.Defaults.Model, "Which model to generate a response with",
	)
	// Model names are not file paths. NoFileComp stops the shell from falling
	// back to filename completion when no model matches. The error can only
	// report an undefined or already-registered flag, neither possible here.
	_ = cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
		return ctx.GetModels(), cobra.ShellCompDirectiveNoFileComp
	})

	// -a, --agent
	f.StringVarP(&ctx.Config.Defaults.Agent, "agent", "a", ctx.Config.Defaults.Agent, "Which agent to interact with")
	// Same reasoning as --model: agent names are not file paths.
	_ = cmd.RegisterFlagCompletionFunc("agent", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
		return ctx.GetAgents(), cobra.ShellCompDirectiveNoFileComp
	})

	// --max-length
	f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
	// -t, --temperature
	f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
	// --system-prompt
	f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
	// --system-prompt-file
	f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
	// A prompt may come from the flag or from a file, not both.
	cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
}
// validateGenerationFlags checks the parsed --model and --agent values on cmd
// against the models and agents reported by ctx.
//
// An empty value skips validation for that flag. The agent name "none" is
// always accepted. It returns an error if a flag cannot be read or names an
// unknown model or agent, and nil otherwise.
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
	flags := cmd.Flags()

	model, err := flags.GetString("model")
	switch {
	case err != nil:
		return fmt.Errorf("Error parsing --model: %w", err)
	case model != "" && !slices.Contains(ctx.GetModels(), model):
		return fmt.Errorf("Unknown model: %s", model)
	}

	agent, err := flags.GetString("agent")
	switch {
	case err != nil:
		return fmt.Errorf("Error parsing --agent: %w", err)
	case agent == "", agent == "none":
		// No agent was requested, so there is nothing to look up.
	case !slices.Contains(ctx.GetAgents(), agent):
		return fmt.Errorf("Unknown agent: %s", agent)
	}

	return nil
}
// inputFromArgsOrEditor returns the user's message text.
//
// If args is non-empty, its elements are joined with single spaces. If args
// is empty, util.InputFromEditor is called with placeholder, the temp-file
// pattern "message.*.md" and existingMessage, and its result is used.
// NOTE(review): per the original contract, placeholder populates the editor
// and is stripped from the output, and existingMessage presumably pre-fills
// the editor. Both happen inside util.InputFromEditor; confirm there. If the
// editor input cannot be obtained, lmcli.Fatal is called.
//
// Leading and trailing whitespace is removed from the result. TrimSpace is
// used rather than trimming " \t\n" so that the trailing "\r" left by editors
// saving CRLF line endings is removed too.
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) string {
	if len(args) > 0 {
		return strings.TrimSpace(strings.Join(args, " "))
	}

	input, err := util.InputFromEditor(placeholder, "message.*.md", existingMessage)
	if err != nil {
		lmcli.Fatal("Failed to get input: %v\n", err)
	}
	return strings.TrimSpace(input)
}