Compare commits

...

2 Commits

5 changed files with 90 additions and 33 deletions

View File

@ -2,8 +2,8 @@ package cli
import ( import (
"fmt" "fmt"
"os"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
@ -66,7 +66,7 @@ var replyCmd = &cobra.Command{
var newCmd = &cobra.Command{ var newCmd = &cobra.Command{
Use: "new", Use: "new",
Short: "Start a new conversation", Short: "Start a new conversation",
Long: `Start a new conversation with the Large Language Model.`, Long: `Start a new conversation with the Large Language Model.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
messageContents, err := InputFromEditor("# What would you like to say?\n", "message.*.md") messageContents, err := InputFromEditor("# What would you like to say?\n", "message.*.md")
if err != nil { if err != nil {
@ -83,12 +83,14 @@ var newCmd = &cobra.Command{
messages := []Message{ messages := []Message{
{ {
Role: "user",
OriginalContent: messageContents, OriginalContent: messageContents,
Role: "user",
}, },
} }
err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout) receiver := make(chan string)
go HandleDelayedResponse(receiver)
err = CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
return return
@ -99,9 +101,9 @@ var newCmd = &cobra.Command{
} }
var promptCmd = &cobra.Command{ var promptCmd = &cobra.Command{
Use: "prompt", Use: "prompt",
Short: "Do a one-shot prompt", Short: "Do a one-shot prompt",
Long: `Prompt the Large Language Model and get a response.`, Long: `Prompt the Large Language Model and get a response.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
message := strings.Join(args, " ") message := strings.Join(args, " ")
if len(strings.Trim(message, " \t\n")) == 0 { if len(strings.Trim(message, " \t\n")) == 0 {
@ -111,12 +113,14 @@ var promptCmd = &cobra.Command{
messages := []Message{ messages := []Message{
{ {
Role: "user",
OriginalContent: message, OriginalContent: message,
Role: "user",
}, },
} }
err := CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout) receiver := make(chan string)
go HandleDelayedResponse(receiver)
err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
if err != nil { if err != nil {
Fatal("%v\n", err) Fatal("%v\n", err)
return return
@ -128,5 +132,5 @@ var promptCmd = &cobra.Command{
func NewRootCmd() *cobra.Command { func NewRootCmd() *cobra.Command {
rootCmd.AddCommand(newCmd, promptCmd) rootCmd.AddCommand(newCmd, promptCmd)
return rootCmd; return rootCmd
} }

View File

@ -37,7 +37,6 @@ func InitializeConfig() *Config {
defaultConfig := &Config{} defaultConfig := &Config{}
defaultConfig.OpenAI.APIKey = "your_key_here" defaultConfig.OpenAI.APIKey = "your_key_here"
file, err := os.Create(configFile) file, err := os.Create(configFile)
if err != nil { if err != nil {
Fatal("Could not open config file for writing: %v", err) Fatal("Could not open config file for writing: %v", err)

View File

@ -3,22 +3,22 @@ package cli
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"io" "io"
openai "github.com/sashabaranov/go-openai" openai "github.com/sashabaranov/go-openai"
) )
func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) { func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest {
chatCompletionMessages := []openai.ChatCompletionMessage{ chatCompletionMessages := []openai.ChatCompletionMessage{
{ {
Role: "system", Role: "system",
Content: system, Content: system,
}, },
} }
for _, m := range(messages) { for _, m := range messages {
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{ chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
Role: m.Role, Role: m.Role,
Content: m.OriginalContent, Content: m.OriginalContent,
}) })
} }
@ -26,8 +26,8 @@ func CreateChatCompletionRequest(system string, messages []Message) (*openai.Cha
return &openai.ChatCompletionRequest{ return &openai.ChatCompletionRequest{
Model: openai.GPT4, Model: openai.GPT4,
MaxTokens: 256, MaxTokens: 256,
Messages: chatCompletionMessages, Messages: chatCompletionMessages,
Stream: true, Stream: true,
} }
} }
@ -47,13 +47,17 @@ func CreateChatCompletion(system string, messages []Message) (string, error) {
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error { // CreateChatCompletionStream submits an streaming Chat Completion API request
// and sends the received data to the output channel.
func CreateChatCompletionStream(system string, messages []Message, output chan string) error {
client := openai.NewClient(config.OpenAI.APIKey) client := openai.NewClient(config.OpenAI.APIKey)
ctx := context.Background() ctx := context.Background()
req := CreateChatCompletionRequest(system, messages) req := CreateChatCompletionRequest(system, messages)
req.Stream = true req.Stream = true
defer close(output)
stream, err := client.CreateChatCompletionStream(ctx, *req) stream, err := client.CreateChatCompletionStream(ctx, *req)
if err != nil { if err != nil {
return err return err
@ -68,10 +72,9 @@ func CreateChatCompletionStream(system string, messages []Message, output io.Wri
} }
if err != nil { if err != nil {
//fmt.Printf("\nStream error: %v\n", err)
return err return err
} }
fmt.Fprint(output, response.Choices[0].Delta.Content) output <- response.Choices[0].Delta.Content
} }
} }

View File

@ -4,33 +4,33 @@ import (
"database/sql" "database/sql"
"os" "os"
"path/filepath" "path/filepath"
"gorm.io/gorm"
"gorm.io/driver/sqlite"
sqids "github.com/sqids/sqids-go" sqids "github.com/sqids/sqids-go"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
) )
type Store struct { type Store struct {
db *gorm.DB db *gorm.DB
sqids *sqids.Sqids sqids *sqids.Sqids
} }
type Message struct { type Message struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ConversationID uint `gorm:"foreignKey:ConversationID"` ConversationID uint `gorm:"foreignKey:ConversationID"`
Conversation Conversation Conversation Conversation
OriginalContent string OriginalContent string
Role string // 'user' or 'assistant' Role string // 'user' or 'assistant'
} }
type Conversation struct { type Conversation struct {
ID uint `gorm:"primaryKey"` ID uint `gorm:"primaryKey"`
ShortName sql.NullString ShortName sql.NullString
Title string Title string
} }
func getDataDir() string { func getDataDir() string {
var dataDir string; var dataDir string
xdgDataHome := os.Getenv("XDG_DATA_HOME") xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" { if xdgDataHome != "" {
@ -57,7 +57,7 @@ func InitializeStore() *Store {
&Message{}, &Message{},
} }
for _, x := range(models) { for _, x := range models {
err := db.AutoMigrate(x) err := db.AutoMigrate(x)
if err != nil { if err != nil {
Fatal("Could not perform database migrations: %v\n", err) Fatal("Could not perform database migrations: %v\n", err)
@ -76,7 +76,7 @@ func (s *Store) SaveConversation(conversation *Conversation) error {
} }
if !conversation.ShortName.Valid { if !conversation.ShortName.Valid {
shortName, _ := s.sqids.Encode([]uint64{ uint64(conversation.ID) }) shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
conversation.ShortName = sql.NullString{String: shortName, Valid: true} conversation.ShortName = sql.NullString{String: shortName, Valid: true}
err = s.db.Save(&conversation).Error err = s.db.Save(&conversation).Error
} }

51
pkg/cli/tty.go Normal file
View File

@ -0,0 +1,51 @@
package cli
import (
"fmt"
"time"
)
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
// received on the signal channel. An empty string sent to the channel to
// noftify the caller that the animation has completed (carriage returned).
func ShowWaitAnimation(signal chan any) {
animationStep := 0
for {
select {
case _ = <-signal:
fmt.Print("\r")
signal <- ""
return
default:
modSix := animationStep % 6
if modSix == 3 || modSix == 0 {
fmt.Print("\r")
}
if modSix < 3 {
fmt.Print(".")
} else {
fmt.Print(" ")
}
animationStep++
time.Sleep(250 * time.Millisecond)
}
}
}
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
// content received on the response channel to stdout. Blocks until the channel
// is closed.
func HandleDelayedResponse(response chan string) {
waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal)
firstChunk := true
for chunk := range response {
if firstChunk {
waitSignal <- ""
<-waitSignal
firstChunk = false
}
fmt.Print(chunk)
}
}