Compare commits

..

No commits in common. "a28a7a00540706bfcc943e79fb8d290a86d0ffc5" and "4590f1db38b83aeb79b13378fea7e91de05500fd" have entirely different histories.

5 changed files with 33 additions and 90 deletions

View File

@ -2,8 +2,8 @@ package cli
import (
"fmt"
"os"
"strings"
"github.com/spf13/cobra"
)
@ -83,14 +83,12 @@ var newCmd = &cobra.Command{
messages := []Message{
{
Role: "user",
OriginalContent: messageContents,
Role: "user",
},
}
receiver := make(chan string)
go HandleDelayedResponse(receiver)
err = CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
err = CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
if err != nil {
Fatal("%v\n", err)
return
@ -113,14 +111,12 @@ var promptCmd = &cobra.Command{
messages := []Message{
{
Role: "user",
OriginalContent: message,
Role: "user",
},
}
receiver := make(chan string)
go HandleDelayedResponse(receiver)
err := CreateChatCompletionStream("You are a helpful assistant.", messages, receiver)
err := CreateChatCompletionStream("You are a helpful assistant.", messages, os.Stdout)
if err != nil {
Fatal("%v\n", err)
return
@ -132,5 +128,5 @@ var promptCmd = &cobra.Command{
func NewRootCmd() *cobra.Command {
rootCmd.AddCommand(newCmd, promptCmd)
return rootCmd
return rootCmd;
}

View File

@ -37,6 +37,7 @@ func InitializeConfig() *Config {
defaultConfig := &Config{}
defaultConfig.OpenAI.APIKey = "your_key_here"
file, err := os.Create(configFile)
if err != nil {
Fatal("Could not open config file for writing: %v", err)

View File

@ -3,12 +3,12 @@ package cli
import (
"context"
"errors"
"fmt"
"io"
openai "github.com/sashabaranov/go-openai"
)
func CreateChatCompletionRequest(system string, messages []Message) *openai.ChatCompletionRequest {
func CreateChatCompletionRequest(system string, messages []Message) (*openai.ChatCompletionRequest) {
chatCompletionMessages := []openai.ChatCompletionMessage{
{
Role: "system",
@ -16,7 +16,7 @@ func CreateChatCompletionRequest(system string, messages []Message) *openai.Chat
},
}
for _, m := range messages {
for _, m := range(messages) {
chatCompletionMessages = append(chatCompletionMessages, openai.ChatCompletionMessage{
Role: m.Role,
Content: m.OriginalContent,
@ -47,17 +47,13 @@ func CreateChatCompletion(system string, messages []Message) (string, error) {
return resp.Choices[0].Message.Content, nil
}
// CreateChatCompletionStream submits an streaming Chat Completion API request
// and sends the received data to the output channel.
func CreateChatCompletionStream(system string, messages []Message, output chan string) error {
func CreateChatCompletionStream(system string, messages []Message, output io.Writer) error {
client := openai.NewClient(config.OpenAI.APIKey)
ctx := context.Background()
req := CreateChatCompletionRequest(system, messages)
req.Stream = true
defer close(output)
stream, err := client.CreateChatCompletionStream(ctx, *req)
if err != nil {
return err
@ -72,9 +68,10 @@ func CreateChatCompletionStream(system string, messages []Message, output chan s
}
if err != nil {
//fmt.Printf("\nStream error: %v\n", err)
return err
}
output <- response.Choices[0].Delta.Content
fmt.Fprint(output, response.Choices[0].Delta.Content)
}
}

View File

@ -4,10 +4,9 @@ import (
"database/sql"
"os"
"path/filepath"
sqids "github.com/sqids/sqids-go"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/driver/sqlite"
sqids "github.com/sqids/sqids-go"
)
type Store struct {
@ -29,8 +28,9 @@ type Conversation struct {
Title string
}
func getDataDir() string {
var dataDir string
var dataDir string;
xdgDataHome := os.Getenv("XDG_DATA_HOME")
if xdgDataHome != "" {
@ -57,7 +57,7 @@ func InitializeStore() *Store {
&Message{},
}
for _, x := range models {
for _, x := range(models) {
err := db.AutoMigrate(x)
if err != nil {
Fatal("Could not perform database migrations: %v\n", err)

View File

@ -1,51 +0,0 @@
package cli
import (
"fmt"
"time"
)
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
// received on the signal channel. An empty string sent to the channel to
// noftify the caller that the animation has completed (carriage returned).
func ShowWaitAnimation(signal chan any) {
animationStep := 0
for {
select {
case _ = <-signal:
fmt.Print("\r")
signal <- ""
return
default:
modSix := animationStep % 6
if modSix == 3 || modSix == 0 {
fmt.Print("\r")
}
if modSix < 3 {
fmt.Print(".")
} else {
fmt.Print(" ")
}
animationStep++
time.Sleep(250 * time.Millisecond)
}
}
}
// HandledDelayedResponse writes a waiting animation (abusing \r) and the
// content received on the response channel to stdout. Blocks until the channel
// is closed.
func HandleDelayedResponse(response chan string) {
waitSignal := make(chan any)
go ShowWaitAnimation(waitSignal)
firstChunk := true
for chunk := range response {
if firstChunk {
waitSignal <- ""
<-waitSignal
firstChunk = false
}
fmt.Print(chunk)
}
}