diff --git a/config.go b/config.go new file mode 100644 index 0000000..71e94a8 --- /dev/null +++ b/config.go @@ -0,0 +1,64 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/go-yaml/yaml" +) + +type Config struct { + OpenAI struct { + APIKey string `yaml:"apiKey"` + } `yaml:"openai"` +} + +func getConfigDir() string { + var configDir string + + xdgConfigHome := os.Getenv("XDG_CONFIG_HOME") + if xdgConfigHome != "" { + configDir = filepath.Join(xdgConfigHome, "lmcli") + } else { + userHomeDir, _ := os.UserHomeDir() + configDir = filepath.Join(userHomeDir, ".config/lmcli") + } + + os.MkdirAll(configDir, 0755) + return configDir +} + +func LoadConfig() *Config { + configFile := filepath.Join(getConfigDir(), "config.yaml") + + configBytes, err := os.ReadFile(configFile) + if os.IsNotExist(err) { + defaultConfig := &Config{} + defaultConfig.OpenAI.APIKey = "your_key_here" + + + file, err := os.Create(configFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Could not open config file for writing: %v", err) + os.Exit(1) + } + + fmt.Printf("Writing default configuration to: %s\n", configFile) + + bytes, _ := yaml.Marshal(defaultConfig) + + _, err = file.Write(bytes) + if err != nil { + fmt.Fprintf(os.Stderr, "Could not save default configuratoin: %v", err) + os.Exit(1) + } + } else if err != nil { + fmt.Fprintf(os.Stderr, "Could not read config file: %v", err) + os.Exit(1) + } + + config := &Config{} + yaml.Unmarshal(configBytes, config) + return config +} diff --git a/go.mod b/go.mod index ff03af0..93837b7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module git.mlow.ca/mlow/lmcli go 1.19 require ( + github.com/go-yaml/yaml v2.1.0+incompatible github.com/sashabaranov/go-openai v1.16.0 github.com/spf13/cobra v1.7.0 github.com/sqids/sqids-go v0.4.1 @@ -14,6 +15,9 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/spf13/pflag v1.0.5 // indirect + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + gopkg.in/yaml.v2 v2.2.2 // indirect ) diff --git a/go.sum b/go.sum index 687200b..2704448 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,22 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= +github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sashabaranov/go-openai v1.16.0 h1:34W6WV84ey6OpW0p2UewZkdMu82AxGC+BzpU6iiauRw= github.com/sashabaranov/go-openai v1.16.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= @@ -17,6 +27,10 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= diff --git a/main.go b/main.go index cf43819..6881c00 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" ) +var config = LoadConfig() var store, storeError = InitializeStore() func checkStore() { diff --git a/openai.go b/openai.go index 9f399f2..b142de9 100644 --- a/openai.go +++ b/openai.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "io" - "os" openai "github.com/sashabaranov/go-openai" ) @@ -29,7 +28,7 @@ func CreateChatCompletionRequest(messages []Message) (openai.ChatCompletionReque // CreateChatCompletion accepts a slice of Message and returns the response // of the Large Language Model. func CreateChatCompletion(system string, messages []Message) (string, error) { - client := openai.NewClient(os.Getenv("OPENAI_APIKEY")) + client := openai.NewClient(config.OpenAI.APIKey) resp, err := client.CreateChatCompletion( context.Background(), CreateChatCompletionRequest(messages), @@ -43,7 +42,7 @@ func CreateChatCompletion(system string, messages []Message) (string, error) { } func CreateChatCompletionStream(system string, messages []Message, output io.Writer) (error) { - client := openai.NewClient(os.Getenv("OPENAI_APIKEY")) + client := openai.NewClient(config.OpenAI.APIKey) ctx := context.Background() req := CreateChatCompletionRequest(messages)