From dfe43179c0b867bddec9f888e2abb97f70311b66 Mon Sep 17 00:00:00 2001
From: Matt Low
Date: Sun, 9 Jun 2024 20:45:18 +0000
Subject: [PATCH] Include token count in api.Chunk

And calculate the tokens/chunk for gemini responses, fixing the tok/s
meter for gemini models.

Further, only consider the first candidate of streamed gemini responses.
---
 pkg/api/api.go                          |  3 ++-
 pkg/api/provider/anthropic/anthropic.go |  2 ++
 pkg/api/provider/google/google.go       | 27 +++++++++++++++----------
 pkg/api/provider/ollama/ollama.go       |  3 ++-
 pkg/api/provider/openai/openai.go       |  5 +++--
 pkg/tui/views/chat/update.go            |  2 +-
 6 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/pkg/api/api.go b/pkg/api/api.go
index dbb3428..41f52fc 100644
--- a/pkg/api/api.go
+++ b/pkg/api/api.go
@@ -9,7 +9,8 @@ import (
 type ReplyCallback func(model.Message)
 
 type Chunk struct {
-	Content string
+	Content    string
+	TokenCount uint
 }
 
 type ChatCompletionClient interface {
diff --git a/pkg/api/provider/anthropic/anthropic.go b/pkg/api/provider/anthropic/anthropic.go
index 7d14f7f..a7426da 100644
--- a/pkg/api/provider/anthropic/anthropic.go
+++ b/pkg/api/provider/anthropic/anthropic.go
@@ -244,6 +244,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				sb.WriteString(text)
 				output <- api.Chunk{
 					Content: text,
+					TokenCount: 1,
 				}
 			case "content_block_stop":
 				// ignore?
@@ -266,6 +267,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 					sb.WriteString(FUNCTION_STOP_SEQUENCE)
 					output <- api.Chunk{
 						Content: FUNCTION_STOP_SEQUENCE,
+						TokenCount: 1,
 					}
 
 					funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
diff --git a/pkg/api/provider/google/google.go b/pkg/api/provider/google/google.go
index 8e44ab0..06d6cba 100644
--- a/pkg/api/provider/google/google.go
+++ b/pkg/api/provider/google/google.go
@@ -366,6 +366,8 @@ func (c *Client) CreateChatCompletionStream(
 	var toolCalls []FunctionCall
 
 	reader := bufio.NewReader(resp.Body)
+
+	lastTokenCount := 0
 	for {
 		line, err := reader.ReadBytes('\n')
 		if err != nil {
@@ -382,22 +384,25 @@
 
 		line = bytes.TrimPrefix(line, []byte("data: "))
 
-		var streamResp GenerateContentResponse
-		err = json.Unmarshal(line, &streamResp)
+		var resp GenerateContentResponse
+		err = json.Unmarshal(line, &resp)
 		if err != nil {
 			return "", err
 		}
 
-		for _, candidate := range streamResp.Candidates {
-			for _, part := range candidate.Content.Parts {
-				if part.FunctionCall != nil {
-					toolCalls = append(toolCalls, *part.FunctionCall)
-				} else if part.Text != "" {
-					output <- api.Chunk {
-						Content: part.Text,
-					}
-					content.WriteString(part.Text)
+		tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
+		lastTokenCount += tokens
+
+		choice := resp.Candidates[0]
+		for _, part := range choice.Content.Parts {
+			if part.FunctionCall != nil {
+				toolCalls = append(toolCalls, *part.FunctionCall)
+			} else if part.Text != "" {
+				output <- api.Chunk{
+					Content: part.Text,
+					TokenCount: uint(tokens),
 				}
+				content.WriteString(part.Text)
 			}
 		}
 	}
diff --git a/pkg/api/provider/ollama/ollama.go b/pkg/api/provider/ollama/ollama.go
index 4825fa3..a7b9ca5 100644
--- a/pkg/api/provider/ollama/ollama.go
+++ b/pkg/api/provider/ollama/ollama.go
@@ -182,7 +182,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
 
 		if len(streamResp.Message.Content) > 0 {
 			output <- api.Chunk{
-				Content: streamResp.Message.Content,
+				Content:    streamResp.Message.Content,
+				TokenCount: 1,
 			}
 			content.WriteString(streamResp.Message.Content)
 		}
diff --git a/pkg/api/provider/openai/openai.go b/pkg/api/provider/openai/openai.go
index bc291cd..9bd199c 100644
--- a/pkg/api/provider/openai/openai.go
+++ b/pkg/api/provider/openai/openai.go
@@ -319,8 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 		}
 		if len(delta.Content) > 0 {
-			output <- api.Chunk {
-				Content: delta.Content,
+			output <- api.Chunk{
+				Content:    delta.Content,
+				TokenCount: 1,
 			}
 			content.WriteString(delta.Content)
 		}
diff --git a/pkg/tui/views/chat/update.go b/pkg/tui/views/chat/update.go
index 369b65d..e385fe3 100644
--- a/pkg/tui/views/chat/update.go
+++ b/pkg/tui/views/chat/update.go
@@ -111,7 +111,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		m.replyCursor.Blink = false
 		cmds = append(cmds, m.replyCursor.BlinkCmd())
 
-		m.tokenCount++
+		m.tokenCount += msg.TokenCount
 		m.elapsed = time.Now().Sub(m.startTime)
 	case msgResponse:
 		cmds = append(cmds, m.waitForResponse()) // wait for the next response
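
For context on how these counts feed the meter: gemini's streamed
usageMetadata.candidatesTokenCount is a running total, so the patch derives
each chunk's TokenCount as the delta from the previous event, and the TUI
simply accumulates msg.TokenCount. Below is a minimal, self-contained sketch
of that arithmetic; the Chunk struct mirrors api.Chunk from the patch, while
the cumulative counts, chunk texts, and sleep are hypothetical stand-ins for
a real stream.

package main

import (
	"fmt"
	"time"
)

// Chunk mirrors api.Chunk as defined by this patch.
type Chunk struct {
	Content    string
	TokenCount uint
}

func main() {
	// Hypothetical cumulative candidatesTokenCount values as a gemini
	// stream might report them across three SSE events.
	cumulative := []int{3, 7, 12}
	texts := []string{"Hello", ", world", "!"}

	start := time.Now()
	lastTokenCount := 0
	var tokenCount uint // plays the role of m.tokenCount in update.go

	for i, c := range cumulative {
		time.Sleep(50 * time.Millisecond) // stand-in for network latency

		// Per-chunk tokens are the delta between cumulative reports,
		// the same computation the google provider now performs.
		tokens := c - lastTokenCount
		lastTokenCount += tokens

		chunk := Chunk{Content: texts[i], TokenCount: uint(tokens)}
		tokenCount += chunk.TokenCount // m.tokenCount += msg.TokenCount
	}

	elapsed := time.Since(start)
	fmt.Printf("%d tokens in %s (%.1f tok/s)\n",
		tokenCount, elapsed.Round(time.Millisecond),
		float64(tokenCount)/elapsed.Seconds())
}

The other providers emit one small delta per chunk, so the patch approximates
them at TokenCount: 1; a single gemini event can carry many tokens, which is
presumably why counting chunks (the old m.tokenCount++) under-read the tok/s
meter for gemini models.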