Compare commits
2 Commits
4590f1db38
...
a28a7a0054
Author | SHA1 | Date | |
---|---|---|---|
a28a7a0054 | |||
200ec57f29 |
@ -2,8 +2,8 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@ -83,12 +83,14 @@ var newCmd = &cobra.Command{
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
OriginalContent: messageContents,
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
},
|
||||
}
|
||||
|
||||
err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err = CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
return
|
||||
@ -111,12 +113,14 @@ var promptCmd = &cobra.Command{
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
OriginalContent: message,
|
||||
Role: "user",
|
||||
OriginalContent: message,
|
||||
},
|
||||
}
|
||||
|
||||
err := CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
|
||||
receiver := make(chan string)
|
||||
go HandleDelayedResponse(receiver)
|
||||
err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
return
|
||||
@ -128,5 +132,5 @@ var promptCmd = &cobra.Command{
|
||||
|
||||
func NewRootCmd() *cobra.Command {
|
||||
rootCmd.AddCommand(newCmd, promptCmd)
|
||||
return rootCmd;
|
||||
return rootCmd
|
||||
}
|
||||
|
@ -37,7 +37,6 @@ func InitializeConfig() *Config {
|
||||
defaultConfig := &Config{}
|
||||
defaultConfig.OpenAI.APIKey = "your_key_here"
|
||||
|
||||
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
Fatal("Could not open config file for writing: %v", err)
|
||||
|
@ -3,12 +3,12 @@ package cli
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) {
|
||||
func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: "system",
|
||||
@ -16,7 +16,7 @@ func CreateChatCompletionRequest(system string, messages []Message) (*openai.Cha
|
||||
},
|
||||
}
|
||||
|
||||
for _, m := range(messages) {
|
||||
for _, m := range messages {
|
||||
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Content: m.OriginalContent,
|
||||
@ -47,13 +47,17 @@ func CreateChatCompletion(system string, messages []Message) (string, error) {
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error {
|
||||
// CreateChatCompletionStream submits an streaming Chat Completion API request
|
||||
// and sends the received data to the output channel.
|
||||
func CreateChatCompletionStream(system string, messages []Message, output chan string) error {
|
||||
client := openai.NewClient(config.OpenAI.APIKey)
|
||||
ctx := context.Background()
|
||||
|
||||
req := CreateChatCompletionRequest(system, messages)
|
||||
req.Stream = true
|
||||
|
||||
defer close(output)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(ctx, *req)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -68,10 +72,9 @@ func CreateChatCompletionStream(system string, messages []Message, output io.Wri
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
//fmt.Printf("\nStream error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Fprint(output, response.Choices[0].Delta.Content)
|
||||
output <- response.Choices[0].Delta.Content
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,10 @@ import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/driver/sqlite"
|
||||
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
@ -28,9 +29,8 @@ type Conversation struct {
|
||||
Title string
|
||||
}
|
||||
|
||||
|
||||
func getDataDir() string {
|
||||
var dataDir string;
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
@ -57,7 +57,7 @@ func InitializeStore() *Store {
|
||||
&Message{},
|
||||
}
|
||||
|
||||
for _, x := range(models) {
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
Fatal("Could not perform database migrations: %v\n", err)
|
||||
|
51
pkg/cli/tty.go
Normal file
51
pkg/cli/tty.go
Normal file
@ -0,0 +1,51 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// noftify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
fmt.Print("\r")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
|
||||
// content received on the response channel to stdout. Blocks until the channel
|
||||
// is closed.
|
||||
func HandleDelayedResponse(response chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range response {
|
||||
if firstChunk {
|
||||
waitSignal <- ""
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user