Include token count in api.Chunk

Also calculate the tokens per chunk for Gemini responses, fixing the tok/s
meter for Gemini models.

Further, only consider the first candidate of streamed Gemini responses.
Matt Low 2024-06-09 20:45:18 +00:00
parent 42c3297e54
commit dfe43179c0
6 changed files with 26 additions and 16 deletions
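
The fix hinges on each streamed api.Chunk now carrying a TokenCount, which the TUI accumulates alongside the elapsed streaming time; a tokens-per-second figure can then be derived from real token counts instead of the number of chunks received. A minimal sketch of that calculation, assuming the accumulated total and elapsed duration tracked in the TUI diff below; the function and its formatting are illustrative only and not part of this commit:

	// tokensPerSecond derives a tok/s reading from the accumulated per-chunk
	// token counts and the elapsed streaming time. Illustrative sketch only;
	// the actual meter rendering code is not shown in this diff.
	package main

	import (
		"fmt"
		"time"
	)

	func tokensPerSecond(tokenCount uint, elapsed time.Duration) float64 {
		if elapsed <= 0 {
			return 0
		}
		return float64(tokenCount) / elapsed.Seconds()
	}

	func main() {
		fmt.Printf("%.1f tok/s\n", tokensPerSecond(42, 3*time.Second)) // 14.0 tok/s
	}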

View File

@@ -10,6 +10,7 @@ type ReplyCallback func(model.Message)
 
 type Chunk struct {
 	Content string
+	TokenCount uint
 }
 
 type ChatCompletionClient interface {

View File

@@ -244,6 +244,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 			sb.WriteString(text)
 			output <- api.Chunk{
 				Content: text,
+				TokenCount: 1,
 			}
 		case "content_block_stop":
 			// ignore?
@@ -266,6 +267,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 			sb.WriteString(FUNCTION_STOP_SEQUENCE)
 			output <- api.Chunk{
 				Content: FUNCTION_STOP_SEQUENCE,
+				TokenCount: 1,
 			}
 
 			funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE

View File

@@ -366,6 +366,8 @@ func (c *Client) CreateChatCompletionStream(
 	var toolCalls []FunctionCall
 
 	reader := bufio.NewReader(resp.Body)
+
+	lastTokenCount := 0
 	for {
 		line, err := reader.ReadBytes('\n')
 		if err != nil {
@@ -382,25 +384,28 @@ func (c *Client) CreateChatCompletionStream(
 			line = bytes.TrimPrefix(line, []byte("data: "))
 
-			var streamResp GenerateContentResponse
-			err = json.Unmarshal(line, &streamResp)
+			var resp GenerateContentResponse
+			err = json.Unmarshal(line, &resp)
 			if err != nil {
 				return "", err
 			}
 
-			for _, candidate := range streamResp.Candidates {
-				for _, part := range candidate.Content.Parts {
-					if part.FunctionCall != nil {
-						toolCalls = append(toolCalls, *part.FunctionCall)
-					} else if part.Text != "" {
-						output <- api.Chunk {
-							Content: part.Text,
-						}
-						content.WriteString(part.Text)
-					}
-				}
-			}
+			tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
+			lastTokenCount += tokens
+
+			choice := resp.Candidates[0]
+			for _, part := range choice.Content.Parts {
+				if part.FunctionCall != nil {
+					toolCalls = append(toolCalls, *part.FunctionCall)
+				} else if part.Text != "" {
+					output <- api.Chunk{
+						Content: part.Text,
+						TokenCount: uint(tokens),
+					}
+					content.WriteString(part.Text)
+				}
+			}
 
 			// If there are function calls, handle them and recurse
 			if len(toolCalls) > 0 {
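
Gemini's streamed GenerateContentResponse reports a cumulative UsageMetadata.CandidatesTokenCount rather than a per-event figure, so the hunk above turns it into a per-chunk value by subtracting the last cumulative count seen. A standalone sketch of that delta bookkeeping, using made-up cumulative values:

	// Converting cumulative candidate token counts into per-chunk deltas,
	// mirroring the lastTokenCount bookkeeping above. The values are made up.
	package main

	import "fmt"

	func main() {
		cumulative := []int{5, 12, 20} // as if reported by successive stream events
		lastTokenCount := 0
		for _, c := range cumulative {
			tokens := c - lastTokenCount
			lastTokenCount += tokens
			fmt.Println(tokens) // prints 5, then 7, then 8
		}
	}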

View File

@@ -183,6 +183,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
 		if len(streamResp.Message.Content) > 0 {
 			output <- api.Chunk{
 				Content: streamResp.Message.Content,
+				TokenCount: 1,
 			}
 			content.WriteString(streamResp.Message.Content)
 		}

View File

@@ -319,8 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 				}
 			}
 			if len(delta.Content) > 0 {
-				output <- api.Chunk {
+				output <- api.Chunk{
 					Content: delta.Content,
+					TokenCount: 1,
 				}
 				content.WriteString(delta.Content)
 			}

View File

@@ -111,7 +111,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 		m.replyCursor.Blink = false
 		cmds = append(cmds, m.replyCursor.BlinkCmd())
 
-		m.tokenCount++
+		m.tokenCount += msg.TokenCount
 		m.elapsed = time.Now().Sub(m.startTime)
 	case msgResponse:
 		cmds = append(cmds, m.waitForResponse()) // wait for the next response