Compare commits
No commits in common. "a28a7a00540706bfcc943e79fb8d290a86d0ffc5" and "4590f1db38b83aeb79b13378fea7e91de05500fd" have entirely different histories.
a28a7a0054
...
4590f1db38
@ -2,8 +2,8 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -66,7 +66,7 @@ var replyCmd = &cobra.Command{
|
||||
var newCmd = &cobra.Command{
|
||||
Use: "new",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
messageContents, err := InputFromEditor("# What would you like to say?\n", "message.*.md")
|
||||
if err != nil {
|
||||
@ -83,14 +83,12 @@ var newCmd = &cobra.Command{
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
Role: "user",
|
||||
},
|
||||
}
|
||||
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err = CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||
err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
return
|
||||
@ -101,9 +99,9 @@ var newCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
var promptCmd = &cobra.Command{
|
||||
Use: "prompt",
|
||||
Use: "prompt",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
message := strings.Join(args, " ")
|
||||
if len(strings.Trim(message, " \t\n")) == 0 {
|
||||
@ -113,14 +111,12 @@ var promptCmd = &cobra.Command{
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: message,
|
||||
Role: "user",
|
||||
},
|
||||
}
|
||||
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||
err := CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
return
|
||||
@ -132,5 +128,5 @@ var promptCmd = &cobra.Command{
|
||||
|
||||
func NewRootCmd() *cobra.Command {
|
||||
rootCmd.AddCommand(newCmd, promptCmd)
|
||||
return rootCmd
|
||||
return rootCmd;
|
||||
}
|
||||
|
@ -37,6 +37,7 @@ func InitializeConfig() *Config {
|
||||
defaultConfig := &Config{}
|
||||
defaultConfig.OpenAI.APIKey = "your_key_here"
|
||||
|
||||
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
Fatal("Could not open config file for writing: %v", err)
|
||||
|
@ -3,22 +3,22 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest {
|
||||
func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "system",
|
||||
Role: "system",
|
||||
Content: system,
|
||||
},
|
||||
}
|
||||
|
||||
for _, m := range messages {
|
||||
for _, m := range(messages) {
|
||||
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Role: m.Role,
|
||||
Content: m.OriginalContent,
|
||||
})
|
||||
}
|
||||
@ -26,8 +26,8 @@ func CreateChatCompletionRequest(system string, messages []Message) *openai.Chat
|
||||
return &openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4,
|
||||
MaxTokens: 256,
|
||||
Messages: chatCompletionMessages,
|
||||
Stream: true,
|
||||
Messages: chatCompletionMessages,
|
||||
Stream: true,
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,17 +47,13 @@ func CreateChatCompletion(system string, messages []Message) (string, error) {
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream submits an streaming Chat Completion API request
|
||||
// and sends the received data to the output channel.
|
||||
func CreateChatCompletionStream(system string, messages []Message, output chan string) error {
|
||||
func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error {
|
||||
client := openai.NewClient(config.OpenAI.APIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
req := CreateChatCompletionRequest(system, messages)
|
||||
req.Stream = true
|
||||
|
||||
defer close(output)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(ctx, *req)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -72,9 +68,10 @@ func CreateChatCompletionStream(system string, messages []Message, output chan s
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
//fmt.Printf("\nStream error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
output <- response.Choices[0].Delta.Content
|
||||
fmt.Fprint(output, response.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
|
@ -4,33 +4,33 @@ import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/driver/sqlite"
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role string // 'user' or 'assistant'
|
||||
Role string // 'user' or 'assistant'
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
Title string
|
||||
}
|
||||
|
||||
|
||||
func getDataDir() string {
|
||||
var dataDir string
|
||||
var dataDir string;
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
@ -57,7 +57,7 @@ func InitializeStore() *Store {
|
||||
&Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
for _, x := range(models) {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
Fatal("Could not perform database migrations: %v\n", err)
|
||||
@ -76,7 +76,7 @@ func (s *Store) SaveConversation(conversation *Conversation) error {
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
shortName, _ := s.sqids.Encode([]uint64{ uint64(conversation.ID) })
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
@ -1,51 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// noftify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
fmt.Print("\r")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
|
||||
// content received on the response channel to stdout. Blocks until the channel
|
||||
// is closed.
|
||||
func HandleDelayedResponse(response chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range response {
|
||||
if firstChunk {
|
||||
waitSignal <- ""
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user