diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 6fc3db4..8ead74e 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -10,8 +10,10 @@ import ( ) var ( - maxTokens int - model string + maxTokens int + model string + systemPrompt string + systemPromptFile string ) func init() { @@ -19,6 +21,9 @@ func init() { for _, cmd := range inputCmds { cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens") cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use") + cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "The system prompt to use.") + cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file whose contents are used as the system prompt.") + cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file") } rootCmd.AddCommand( @@ -35,6 +40,17 @@ func Execute() error { return rootCmd.Execute() } +func SystemPrompt() string { + if systemPromptFile != "" { + content, err := FileContents(systemPromptFile) + if err != nil { + Fatal("Could not read file contents at %s: %v", systemPromptFile, err) + } + return content + } + return systemPrompt +} + // InputFromArgsOrEditor returns either the provided input from the args slice // (joined with spaces), or if len(args) is 0, opens an editor and returns // whatever input was provided there. placeholder is a string which populates @@ -327,12 +343,11 @@ var newCmd = &cobra.Command{ Fatal("Could not save new conversation: %v\n", err) } - const system = "You are a helpful assistant." messages := []Message{ { ConversationID: conversation.ID, Role: "system", - OriginalContent: system, + OriginalContent: SystemPrompt(), }, { ConversationID: conversation.ID, @@ -396,11 +411,10 @@ var promptCmd = &cobra.Command{ Fatal("No message was provided.\n") } - const system = "You are a helpful assistant." messages := []Message{ { Role: "system", - OriginalContent: system, + OriginalContent: SystemPrompt(), }, { Role: "user", diff --git a/pkg/cli/config.go b/pkg/cli/config.go index 620c429..c4108b0 100644 --- a/pkg/cli/config.go +++ b/pkg/cli/config.go @@ -9,6 +9,9 @@ import ( ) type Config struct { + ModelDefaults *struct { + SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."` + } `yaml:"modelDefaults"` OpenAI *struct { APIKey *string `yaml:"apiKey" default:"your_key_here"` DefaultModel *string `yaml:"defaultModel" default:"gpt-4"` diff --git a/pkg/cli/util.go b/pkg/cli/util.go index 11af566..5acc547 100644 --- a/pkg/cli/util.go +++ b/pkg/cli/util.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "os/exec" + "path/filepath" "reflect" "strconv" "strings" @@ -158,3 +159,15 @@ func SetStructDefaults(data interface{}) bool { } return changed } + +// FileContents returns the string contents of the given file. +// TODO: we should support retrieving the content (or an approximation of) +// non-text documents, e.g. PDFs. +func FileContents(file string) (string, error) { + path := filepath.Clean(file) + content, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.Trim(string(content), "\n\t "), nil +}