Compare commits
2 Commits
4590f1db38
...
a28a7a0054
Author | SHA1 | Date | |
---|---|---|---|
a28a7a0054 | |||
200ec57f29 |
@ -2,8 +2,8 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ var replyCmd = &cobra.Command{
|
|||||||
var newCmd = &cobra.Command{
|
var newCmd = &cobra.Command{
|
||||||
Use: "new",
|
Use: "new",
|
||||||
Short: "Start a new conversation",
|
Short: "Start a new conversation",
|
||||||
Long: `Start a new conversation with the Large Language Model.`,
|
Long: `Start a new conversation with the Large Language Model.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
messageContents, err := InputFromEditor("# What would you like to say?\n", "message.*.md")
|
messageContents, err := InputFromEditor("# What would you like to say?\n", "message.*.md")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -83,12 +83,14 @@ var newCmd = &cobra.Command{
|
|||||||
|
|
||||||
messages := []Message{
|
messages := []Message{
|
||||||
{
|
{
|
||||||
|
Role: "user",
|
||||||
OriginalContent: messageContents,
|
OriginalContent: messageContents,
|
||||||
Role: "user",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
receiver := make(chan string)
|
||||||
|
go HandleDelayedResponse(receiver)
|
||||||
|
err = CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("%v\n", err)
|
||||||
return
|
return
|
||||||
@ -99,9 +101,9 @@ var newCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var promptCmd = &cobra.Command{
|
var promptCmd = &cobra.Command{
|
||||||
Use: "prompt",
|
Use: "prompt",
|
||||||
Short: "Do a one-shot prompt",
|
Short: "Do a one-shot prompt",
|
||||||
Long: `Prompt the Large Language Model and get a response.`,
|
Long: `Prompt the Large Language Model and get a response.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
message := strings.Join(args, " ")
|
message := strings.Join(args, " ")
|
||||||
if len(strings.Trim(message, " \t\n")) == 0 {
|
if len(strings.Trim(message, " \t\n")) == 0 {
|
||||||
@ -111,12 +113,14 @@ var promptCmd = &cobra.Command{
|
|||||||
|
|
||||||
messages := []Message{
|
messages := []Message{
|
||||||
{
|
{
|
||||||
|
Role: "user",
|
||||||
OriginalContent: message,
|
OriginalContent: message,
|
||||||
Role: "user",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err := CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
receiver := make(chan string)
|
||||||
|
go HandleDelayedResponse(receiver)
|
||||||
|
err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
Fatal("%v\n", err)
|
||||||
return
|
return
|
||||||
@ -128,5 +132,5 @@ var promptCmd = &cobra.Command{
|
|||||||
|
|
||||||
func NewRootCmd() *cobra.Command {
|
func NewRootCmd() *cobra.Command {
|
||||||
rootCmd.AddCommand(newCmd, promptCmd)
|
rootCmd.AddCommand(newCmd, promptCmd)
|
||||||
return rootCmd;
|
return rootCmd
|
||||||
}
|
}
|
||||||
|
@ -37,7 +37,6 @@ func InitializeConfig() *Config {
|
|||||||
defaultConfig := &Config{}
|
defaultConfig := &Config{}
|
||||||
defaultConfig.OpenAI.APIKey = "your_key_here"
|
defaultConfig.OpenAI.APIKey = "your_key_here"
|
||||||
|
|
||||||
|
|
||||||
file, err := os.Create(configFile)
|
file, err := os.Create(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("Could not open config file for writing: %v", err)
|
Fatal("Could not open config file for writing: %v", err)
|
||||||
|
@ -3,22 +3,22 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
openai "github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) {
|
func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest {
|
||||||
chatCompletionMessages := []openai.ChatCompletionMessage{
|
chatCompletionMessages := []openai.ChatCompletionMessage{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: system,
|
Content: system,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, m := range(messages) {
|
for _, m := range messages {
|
||||||
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
||||||
Role: m.Role,
|
Role: m.Role,
|
||||||
Content: m.OriginalContent,
|
Content: m.OriginalContent,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -26,8 +26,8 @@ func CreateChatCompletionRequest(system string, messages []Message) (*openai.Cha
|
|||||||
return &openai.ChatCompletionRequest{
|
return &openai.ChatCompletionRequest{
|
||||||
Model: openai.GPT4,
|
Model: openai.GPT4,
|
||||||
MaxTokens: 256,
|
MaxTokens: 256,
|
||||||
Messages: chatCompletionMessages,
|
Messages: chatCompletionMessages,
|
||||||
Stream: true,
|
Stream: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -47,13 +47,17 @@ func CreateChatCompletion(system string, messages []Message) (string, error) {
|
|||||||
return resp.Choices[0].Message.Content, nil
|
return resp.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error {
|
// CreateChatCompletionStream submits an streaming Chat Completion API request
|
||||||
|
// and sends the received data to the output channel.
|
||||||
|
func CreateChatCompletionStream(system string, messages []Message, output chan string) error {
|
||||||
client := openai.NewClient(config.OpenAI.APIKey)
|
client := openai.NewClient(config.OpenAI.APIKey)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
req := CreateChatCompletionRequest(system, messages)
|
req := CreateChatCompletionRequest(system, messages)
|
||||||
req.Stream = true
|
req.Stream = true
|
||||||
|
|
||||||
|
defer close(output)
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, *req)
|
stream, err := client.CreateChatCompletionStream(ctx, *req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -68,10 +72,9 @@ func CreateChatCompletionStream(system string, messages []Message, output io.Wri
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//fmt.Printf("\nStream error: %v\n", err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprint(output, response.Choices[0].Delta.Content)
|
output <- response.Choices[0].Delta.Content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,33 +4,33 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"gorm.io/gorm"
|
|
||||||
"gorm.io/driver/sqlite"
|
|
||||||
sqids "github.com/sqids/sqids-go"
|
sqids "github.com/sqids/sqids-go"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
sqids *sqids.Sqids
|
sqids *sqids.Sqids
|
||||||
}
|
}
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||||
Conversation Conversation
|
Conversation Conversation
|
||||||
OriginalContent string
|
OriginalContent string
|
||||||
Role string // 'user' or 'assistant'
|
Role string // 'user' or 'assistant'
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
ID uint `gorm:"primaryKey"`
|
ID uint `gorm:"primaryKey"`
|
||||||
ShortName sql.NullString
|
ShortName sql.NullString
|
||||||
Title string
|
Title string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func getDataDir() string {
|
func getDataDir() string {
|
||||||
var dataDir string;
|
var dataDir string
|
||||||
|
|
||||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||||
if xdgDataHome != "" {
|
if xdgDataHome != "" {
|
||||||
@ -57,7 +57,7 @@ func InitializeStore() *Store {
|
|||||||
&Message{},
|
&Message{},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, x := range(models) {
|
for _, x := range models {
|
||||||
err := db.AutoMigrate(x)
|
err := db.AutoMigrate(x)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("Could not perform database migrations: %v\n", err)
|
Fatal("Could not perform database migrations: %v\n", err)
|
||||||
@ -76,7 +76,7 @@ func (s *Store) SaveConversation(conversation *Conversation) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !conversation.ShortName.Valid {
|
if !conversation.ShortName.Valid {
|
||||||
shortName, _ := s.sqids.Encode([]uint64{ uint64(conversation.ID) })
|
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||||
err = s.db.Save(&conversation).Error
|
err = s.db.Save(&conversation).Error
|
||||||
}
|
}
|
||||||
|
51
pkg/cli/tty.go
Normal file
51
pkg/cli/tty.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||||
|
// received on the signal channel. An empty string sent to the channel to
|
||||||
|
// noftify the caller that the animation has completed (carriage returned).
|
||||||
|
func ShowWaitAnimation(signal chan any) {
|
||||||
|
animationStep := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case _ = <-signal:
|
||||||
|
fmt.Print("\r")
|
||||||
|
signal <- ""
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
modSix := animationStep % 6
|
||||||
|
if modSix == 3 || modSix == 0 {
|
||||||
|
fmt.Print("\r")
|
||||||
|
}
|
||||||
|
if modSix < 3 {
|
||||||
|
fmt.Print(".")
|
||||||
|
} else {
|
||||||
|
fmt.Print(" ")
|
||||||
|
}
|
||||||
|
animationStep++
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
|
||||||
|
// content received on the response channel to stdout. Blocks until the channel
|
||||||
|
// is closed.
|
||||||
|
func HandleDelayedResponse(response chan string) {
|
||||||
|
waitSignal := make(chan any)
|
||||||
|
go ShowWaitAnimation(waitSignal)
|
||||||
|
|
||||||
|
firstChunk := true
|
||||||
|
for chunk := range response {
|
||||||
|
if firstChunk {
|
||||||
|
waitSignal <- ""
|
||||||
|
<-waitSignal
|
||||||
|
firstChunk = false
|
||||||
|
}
|
||||||
|
fmt.Print(chunk)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user