2024-02-21 21:55:38 -07:00
package cmd
import (
2024-06-22 22:47:47 -06:00
"fmt"
"slices"
2024-02-21 21:55:38 -07:00
"strings"
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
"git.mlow.ca/mlow/lmcli/pkg/util"
"github.com/spf13/cobra"
)
func RootCmd ( ctx * lmcli . Context ) * cobra . Command {
var root = & cobra . Command {
Use : "lmcli <command> [flags]" ,
Long : ` lmcli - Large Language Model CLI ` ,
SilenceErrors : true ,
SilenceUsage : true ,
Run : func ( cmd * cobra . Command , args [ ] string ) {
cmd . Usage ( )
} ,
}
root . AddCommand (
2024-05-07 01:11:04 -06:00
ChatCmd ( ctx ) ,
ContinueCmd ( ctx ) ,
CloneCmd ( ctx ) ,
EditCmd ( ctx ) ,
ListCmd ( ctx ) ,
NewCmd ( ctx ) ,
PromptCmd ( ctx ) ,
RenameCmd ( ctx ) ,
ReplyCmd ( ctx ) ,
RetryCmd ( ctx ) ,
RemoveCmd ( ctx ) ,
ViewCmd ( ctx ) ,
2024-02-21 21:55:38 -07:00
)
return root
}
2024-06-22 22:47:47 -06:00
func applyGenerationFlags ( ctx * lmcli . Context , cmd * cobra . Command ) {
2024-05-07 01:11:04 -06:00
f := cmd . Flags ( )
2024-06-22 22:47:47 -06:00
// -m, --model
2024-05-07 01:11:04 -06:00
f . StringVarP (
2024-06-22 22:47:47 -06:00
ctx . Config . Defaults . Model , "model" , "m" ,
* ctx . Config . Defaults . Model , "Which model to generate a response with" ,
2024-05-07 01:11:04 -06:00
)
cmd . RegisterFlagCompletionFunc ( "model" , func ( * cobra . Command , [ ] string , string ) ( [ ] string , cobra . ShellCompDirective ) {
return ctx . GetModels ( ) , cobra . ShellCompDirectiveDefault
} )
2024-06-23 12:57:08 -06:00
// -a, --agent
f . StringVarP ( & ctx . Config . Defaults . Agent , "agent" , "a" , ctx . Config . Defaults . Agent , "Which agent to interact with" )
cmd . RegisterFlagCompletionFunc ( "agent" , func ( * cobra . Command , [ ] string , string ) ( [ ] string , cobra . ShellCompDirective ) {
return ctx . GetAgents ( ) , cobra . ShellCompDirectiveDefault
} )
2024-06-22 22:47:47 -06:00
// --max-length
2024-05-07 01:11:04 -06:00
f . IntVar ( ctx . Config . Defaults . MaxTokens , "max-length" , * ctx . Config . Defaults . MaxTokens , "Maximum response tokens" )
2024-06-22 22:47:47 -06:00
// --temperature
2024-05-07 01:11:04 -06:00
f . Float32VarP ( ctx . Config . Defaults . Temperature , "temperature" , "t" , * ctx . Config . Defaults . Temperature , "Sampling temperature" )
2024-06-22 22:47:47 -06:00
// --system-prompt
2024-06-23 10:02:26 -06:00
f . StringVar ( & ctx . Config . Defaults . SystemPrompt , "system-prompt" , ctx . Config . Defaults . SystemPrompt , "System prompt" )
2024-06-22 22:47:47 -06:00
// --system-prompt-file
f . StringVar ( & ctx . Config . Defaults . SystemPromptFile , "system-prompt-file" , ctx . Config . Defaults . SystemPromptFile , "A path to a file containing the system prompt" )
2024-05-07 01:11:04 -06:00
cmd . MarkFlagsMutuallyExclusive ( "system-prompt" , "system-prompt-file" )
2024-02-21 21:55:38 -07:00
}
2024-06-22 22:47:47 -06:00
func validateGenerationFlags ( ctx * lmcli . Context , cmd * cobra . Command ) error {
f := cmd . Flags ( )
model , err := f . GetString ( "model" )
if err != nil {
return fmt . Errorf ( "Error parsing --model: %w" , err )
}
2024-06-23 12:57:08 -06:00
if model != "" && ! slices . Contains ( ctx . GetModels ( ) , model ) {
2024-06-22 22:47:47 -06:00
return fmt . Errorf ( "Unknown model: %s" , model )
}
2024-06-23 12:57:08 -06:00
agent , err := f . GetString ( "agent" )
if err != nil {
return fmt . Errorf ( "Error parsing --agent: %w" , err )
}
if agent != "" && ! slices . Contains ( ctx . GetAgents ( ) , agent ) {
return fmt . Errorf ( "Unknown agent: %s" , agent )
}
2024-06-22 22:47:47 -06:00
return nil
}
2024-02-21 21:55:38 -07:00
// inputFromArgsOrEditor returns the user's message text.
//
// If args is non-empty, the message is args joined with single spaces.
// Otherwise util.InputFromEditor is called with placeholder, the pattern
// "message.*.md" (presumably a temp-file name pattern) and existingMessage,
// and its result is used. Removing placeholder from the edited text is not
// done here; presumably util.InputFromEditor handles it — TODO confirm.
// If the editor call fails, lmcli.Fatal is invoked (presumably exiting).
//
// Leading and trailing whitespace is trimmed from the result. This includes
// the '\r' left behind by editors that write CRLF line endings.
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
	if len(args) > 0 {
		return strings.TrimSpace(strings.Join(args, " "))
	}

	text, err := util.InputFromEditor(placeholder, "message.*.md", existingMessage)
	if err != nil {
		lmcli.Fatal("Failed to get input: %v\n", err)
	}
	return strings.TrimSpace(text)
}