Compare commits

..

5 Commits

5 changed files with 84 additions and 26 deletions

3
go.mod
View File

@ -4,6 +4,7 @@ go 1.21
require ( require (
github.com/go-yaml/yaml v2.1.0+incompatible github.com/go-yaml/yaml v2.1.0+incompatible
github.com/gookit/color v1.5.4
github.com/sashabaranov/go-openai v1.16.0 github.com/sashabaranov/go-openai v1.16.0
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/sqids/sqids-go v0.4.1 github.com/sqids/sqids-go v0.4.1
@ -18,6 +19,8 @@ require (
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
golang.org/x/sys v0.14.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect gopkg.in/yaml.v2 v2.2.2 // indirect
) )

13
go.sum
View File

@ -1,7 +1,11 @@
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
@ -15,6 +19,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
@ -26,11 +32,18 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0= gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=

View File

@ -9,10 +9,18 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
// TODO: allow setting with flag var (
const MAX_TOKENS = 256 maxTokens int
model string
)
func init() { func init() {
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
for _, cmd := range inputCmds {
cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens")
cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use")
}
rootCmd.AddCommand( rootCmd.AddCommand(
lsCmd, lsCmd,
newCmd, newCmd,
@ -197,11 +205,7 @@ var viewCmd = &cobra.Command{
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title) Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
} }
l := len(messages) RenderConversation(messages, false)
for i, message := range messages {
message.RenderTTY(i < l-1)
}
fmt.Println()
}, },
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
compMode := cobra.ShellCompDirectiveNoFileComp compMode := cobra.ShellCompDirectiveNoFileComp
@ -236,6 +240,9 @@ var replyCmd = &cobra.Command{
} }
messageContents, err := InputFromEditor("# How would you like to reply?\n", "reply.*.md") messageContents, err := InputFromEditor("# How would you like to reply?\n", "reply.*.md")
if messageContents == "" {
Fatal("No reply was provided.\n")
}
userReply := Message{ userReply := Message{
ConversationID: conversation.ID, ConversationID: conversation.ID,
@ -249,15 +256,13 @@ var replyCmd = &cobra.Command{
} }
messages = append(messages, userReply) messages = append(messages, userReply)
for _, message := range messages {
message.RenderTTY(true)
}
RenderConversation(messages, true)
assistantReply := Message{ assistantReply := Message{
ConversationID: conversation.ID, ConversationID: conversation.ID,
Role: "assistant", Role: "assistant",
} }
assistantReply.RenderTTY(false) assistantReply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) response := make(chan string)
@ -265,7 +270,7 @@ var replyCmd = &cobra.Command{
response <- HandleDelayedResponse(receiver) response <- HandleDelayedResponse(receiver)
}() }()
err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
} }
@ -329,15 +334,12 @@ var newCmd = &cobra.Command{
} }
} }
for _, message := range messages { RenderConversation(messages, true)
message.RenderTTY(true)
}
reply := Message{ reply := Message{
ConversationID: conversation.ID, ConversationID: conversation.ID,
Role: "assistant", Role: "assistant",
} }
reply.RenderTTY(false) reply.RenderTTY()
receiver := make(chan string) receiver := make(chan string)
response := make(chan string) response := make(chan string)
@ -345,7 +347,7 @@ var newCmd = &cobra.Command{
response <- HandleDelayedResponse(receiver) response <- HandleDelayedResponse(receiver)
}() }()
err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
} }
@ -394,7 +396,7 @@ var promptCmd = &cobra.Command{
receiver := make(chan string) receiver := make(chan string)
go HandleDelayedResponse(receiver) go HandleDelayedResponse(receiver)
err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver) err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
} }

View File

@ -12,6 +12,7 @@ type Config struct {
OpenAI struct { OpenAI struct {
APIKey string `yaml:"apiKey"` APIKey string `yaml:"apiKey"`
DefaultModel string `yaml:"defaultModel"` DefaultModel string `yaml:"defaultModel"`
DefaultMaxLength int `yaml:"defaultMaxLength"`
} `yaml:"openai"` } `yaml:"openai"`
} }

View File

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/gookit/color"
) )
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is // ShowWaitAnimation "draws" an animated ellipses to stdout until something is
@ -58,12 +60,49 @@ func HandleDelayedResponse(response chan string) string {
return sb.String() return sb.String()
} }
func (m *Message) RenderTTY(paddingDown bool) { // RenderConversation renders the given messages, with optional space for a
fmt.Printf("<%s>\n\n", m.FriendlyRole()) // subsequent message. spaceForResponse controls how many newlines are printed
if m.OriginalContent != "" { // after the final message (1 newline if false, 2 if true)
fmt.Print(m.OriginalContent) func RenderConversation(messages []Message, spaceForResponse bool) {
l := len(messages)
for i, message := range messages {
message.RenderTTY()
if i < l-1 || spaceForResponse {
// print an additional space before the next message
fmt.Println()
} }
if paddingDown { }
fmt.Print("\n\n") }
func (m *Message) RenderTTY() {
var messageAge string
if m.CreatedAt.IsZero() {
messageAge = "now"
} else {
now := time.Now()
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
}
var roleStyle color.Style
switch m.Role {
case "system":
roleStyle = color.Style{color.HiRed}
case "user":
roleStyle = color.Style{color.HiGreen}
case "assistant":
roleStyle = color.Style{color.HiBlue}
default:
roleStyle = color.Style{color.FgWhite}
}
roleStyle.Add(color.Bold)
headerColor := color.FgYellow
separator := headerColor.Sprint("===")
timestamp := headerColor.Sprint(messageAge)
role := roleStyle.Sprint(m.FriendlyRole())
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
if m.OriginalContent != "" {
fmt.Println(m.OriginalContent)
} }
} }