Compare commits
5 Commits
cac2a1e80c
...
6426b04e2c
Author | SHA1 | Date | |
---|---|---|---|
6426b04e2c | |||
965043c908 | |||
8bc8312154 | |||
681b52a55c | |||
22e0ff4115 |
3
go.mod
3
go.mod
@ -4,6 +4,7 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/gookit/color v1.5.4
|
||||
github.com/sashabaranov/go-openai v1.16.0
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/sqids/sqids-go v0.4.1
|
||||
@ -18,6 +19,8 @@ require (
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
golang.org/x/sys v0.14.0 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
||||
)
|
||||
|
13
go.sum
13
go.sum
@ -1,7 +1,11 @@
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
@ -15,6 +19,8 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@ -26,11 +32,18 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||
|
@ -9,10 +9,18 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// TODO: allow setting with flag
|
||||
const MAX_TOKENS = 256
|
||||
var (
|
||||
maxTokens int
|
||||
model string
|
||||
)
|
||||
|
||||
func init() {
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().IntVar(&maxTokens, "length", config.OpenAI.DefaultMaxLength, "Max response length in tokens")
|
||||
cmd.Flags().StringVar(&model, "model", config.OpenAI.DefaultModel, "The language model to use")
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(
|
||||
lsCmd,
|
||||
newCmd,
|
||||
@ -197,11 +205,7 @@ var viewCmd = &cobra.Command{
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
message.RenderTTY(i < l-1)
|
||||
}
|
||||
fmt.Println()
|
||||
RenderConversation(messages, false)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
@ -236,6 +240,9 @@ var replyCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
messageContents, err := InputFromEditor("# How would you like to reply?\n", "reply.*.md")
|
||||
if messageContents == "" {
|
||||
Fatal("No reply was provided.\n")
|
||||
}
|
||||
|
||||
userReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
@ -249,15 +256,13 @@ var replyCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
messages = append(messages, userReply)
|
||||
for _, message := range messages {
|
||||
message.RenderTTY(true)
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY(false)
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
@ -265,7 +270,7 @@ var replyCmd = &cobra.Command{
|
||||
response <- HandleDelayedResponse(receiver)
|
||||
}()
|
||||
|
||||
err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
@ -329,15 +334,12 @@ var newCmd = &cobra.Command{
|
||||
}
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
message.RenderTTY(true)
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
reply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
reply.RenderTTY(false)
|
||||
reply.RenderTTY()
|
||||
|
||||
receiver := make(chan string)
|
||||
response := make(chan string)
|
||||
@ -345,7 +347,7 @@ var newCmd = &cobra.Command{
|
||||
response <- HandleDelayedResponse(receiver)
|
||||
}()
|
||||
|
||||
err = CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
|
||||
err = CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
@ -394,7 +396,7 @@ var promptCmd = &cobra.Command{
|
||||
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err := CreateChatCompletionStream(config.OpenAI.DefaultModel, messages, MAX_TOKENS, receiver)
|
||||
err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ type Config struct {
|
||||
OpenAI struct {
|
||||
APIKey string `yaml:"apiKey"`
|
||||
DefaultModel string `yaml:"defaultModel"`
|
||||
DefaultMaxLength int `yaml:"defaultMaxLength"`
|
||||
} `yaml:"openai"`
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gookit/color"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
@ -58,12 +60,49 @@ func HandleDelayedResponse(response chan string) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m *Message) RenderTTY(paddingDown bool) {
|
||||
fmt.Printf("<%s>\n\n", m.FriendlyRole())
|
||||
if m.OriginalContent != "" {
|
||||
fmt.Print(m.OriginalContent)
|
||||
}
|
||||
if paddingDown {
|
||||
fmt.Print("\n\n")
|
||||
// RenderConversation renders the given messages, with optional space for a
|
||||
// subsequent message. spaceForResponse controls how many newlines are printed
|
||||
// after the final message (1 newline if false, 2 if true)
|
||||
func RenderConversation(messages []Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
message.RenderTTY()
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) RenderTTY() {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
var roleStyle color.Style
|
||||
switch m.Role {
|
||||
case "system":
|
||||
roleStyle = color.Style{color.HiRed}
|
||||
case "user":
|
||||
roleStyle = color.Style{color.HiGreen}
|
||||
case "assistant":
|
||||
roleStyle = color.Style{color.HiBlue}
|
||||
default:
|
||||
roleStyle = color.Style{color.FgWhite}
|
||||
}
|
||||
roleStyle.Add(color.Bold)
|
||||
|
||||
headerColor := color.FgYellow
|
||||
separator := headerColor.Sprint("===")
|
||||
timestamp := headerColor.Sprint(messageAge)
|
||||
role := roleStyle.Sprint(m.FriendlyRole())
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.OriginalContent != "" {
|
||||
fmt.Println(m.OriginalContent)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user