diff --git a/pkg/cli/cmd.go b/pkg/cli/cmd.go index 63faccc..efde7af 100644 --- a/pkg/cli/cmd.go +++ b/pkg/cli/cmd.go @@ -97,33 +97,36 @@ var newCmd = &cobra.Command{ OriginalContent: messageContents, }, } - for _, message := range(messages) { + for _, message := range messages { err = store.SaveMessage(&message) if err != nil { - Warn("Could not save %s message: %v\n", message.Role, err) + Warn("Could not save %s message: %v\n", message.Role, err) } } - fmt.Printf("\n\n%s\n\n", system) - fmt.Printf("\n\n%s\n\n", messageContents) - fmt.Print("\n\n") + for _, message := range messages { + message.RenderTTY(true) + } + + reply := Message{ + ConversationID: conversation.ID, + Role: "assistant", + } + + reply.RenderTTY(false) receiver := make(chan string) response := make(chan string) go func() { response <- HandleDelayedResponse(receiver) }() - err = CreateChatCompletionStream(messages, receiver) if err != nil { Fatal("%v\n", err) } - reply := Message{ - ConversationID: conversation.ID, - Role: "assistant", - OriginalContent: <-response, - } + reply.OriginalContent = <-response + err = store.SaveMessage(&reply) if err != nil { Fatal("Could not save reply: %v\n", err) diff --git a/pkg/cli/tty.go b/pkg/cli/tty.go index 5260162..e2dab81 100644 --- a/pkg/cli/tty.go +++ b/pkg/cli/tty.go @@ -55,3 +55,22 @@ func HandleDelayedResponse(response chan string) string { return sb.String() } + +func (m *Message) RenderTTY(paddingDown bool) { + var friendlyRole string + switch m.Role { + case "user": + friendlyRole = "You" + case "system": + friendlyRole = "System" + case "assistant": + friendlyRole = "Assistant" + } + fmt.Printf("<%s>\n\n", friendlyRole) + if m.OriginalContent != "" { + fmt.Print(m.OriginalContent) + } + if paddingDown { + fmt.Print("\n\n") + } +}