Include token count in api.Chunk
Also calculate the tokens per chunk for Gemini responses, fixing the tok/s meter for Gemini models, and only consider the first candidate of streamed Gemini responses.
commit dfe43179c0
parent 42c3297e54
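The core of the Gemini fix is turning the API's running token total into per-chunk counts. A minimal sketch of that delta logic, assuming `CandidatesTokenCount` in each streamed response's `UsageMetadata` is cumulative for the stream so far, as this commit treats it:

```go
package tokencount

// tokenDelta converts a cumulative candidate-token total into the
// per-chunk count that belongs in api.Chunk.TokenCount. last tracks
// the total seen on previous stream events.
func tokenDelta(cumulative int, last *int) int {
	delta := cumulative - *last
	*last = cumulative
	return delta
}
```

For providers whose stream events carry no usage figures here (Anthropic, Ollama, OpenAI), the diff instead falls back to a flat `TokenCount: 1` per chunk.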
@@ -9,7 +9,8 @@ import (
 type ReplyCallback func(model.Message)
 
 type Chunk struct {
 	Content string
+	TokenCount uint
 }
 
 type ChatCompletionClient interface {
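With the new field in place, anything draining the chunk channel can keep a token total alongside the text. A hedged consumer sketch (the channel and helper are illustrative, not code from this repository; `Chunk` mirrors `api.Chunk` after this commit):

```go
package tokencount

import "strings"

// Chunk mirrors api.Chunk after this commit.
type Chunk struct {
	Content    string
	TokenCount uint
}

// consume accumulates streamed text and the total token count,
// which is exactly what a tok/s meter needs.
func consume(chunks <-chan Chunk) (text string, tokens uint) {
	var sb strings.Builder
	for c := range chunks {
		sb.WriteString(c.Content)
		tokens += c.TokenCount
	}
	return sb.String(), tokens
}
```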
@@ -244,6 +244,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				sb.WriteString(text)
 				output <- api.Chunk{
 					Content: text,
+					TokenCount: 1,
 				}
 			case "content_block_stop":
 				// ignore?
@@ -266,6 +267,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				sb.WriteString(FUNCTION_STOP_SEQUENCE)
 				output <- api.Chunk{
 					Content: FUNCTION_STOP_SEQUENCE,
+					TokenCount: 1,
 				}
 
 				funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
@@ -366,6 +366,8 @@ func (c *Client) CreateChatCompletionStream(
 	var toolCalls []FunctionCall
 
 	reader := bufio.NewReader(resp.Body)
+
+	lastTokenCount := 0
 	for {
 		line, err := reader.ReadBytes('\n')
 		if err != nil {
@@ -382,22 +384,25 @@ func (c *Client) CreateChatCompletionStream(
 
 		line = bytes.TrimPrefix(line, []byte("data: "))
 
-		var streamResp GenerateContentResponse
-		err = json.Unmarshal(line, &streamResp)
+		var resp GenerateContentResponse
+		err = json.Unmarshal(line, &resp)
 		if err != nil {
 			return "", err
 		}
 
-		for _, candidate := range streamResp.Candidates {
-			for _, part := range candidate.Content.Parts {
-				if part.FunctionCall != nil {
-					toolCalls = append(toolCalls, *part.FunctionCall)
-				} else if part.Text != "" {
-					output <- api.Chunk {
-						Content: part.Text,
-					}
-					content.WriteString(part.Text)
-				}
+		tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
+		lastTokenCount += tokens
+
+		choice := resp.Candidates[0]
+		for _, part := range choice.Content.Parts {
+			if part.FunctionCall != nil {
+				toolCalls = append(toolCalls, *part.FunctionCall)
+			} else if part.Text != "" {
+				output <- api.Chunk{
+					Content: part.Text,
+					TokenCount: uint(tokens),
+				}
+				content.WriteString(part.Text)
 			}
 		}
 	}
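Since the hunk above differences a running total, the numbers work out like this: if successive stream responses report a `CandidatesTokenCount` of 12, 20, and 27, the emitted chunks carry `TokenCount` values of 12, 8, and 7, summing back to 27. Indexing `resp.Candidates[0]` instead of looping over all candidates attributes the delta to the first candidate only, matching the commit message.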
@@ -182,7 +182,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
 
 		if len(streamResp.Message.Content) > 0 {
 			output <- api.Chunk{
 				Content: streamResp.Message.Content,
+				TokenCount: 1,
 			}
 			content.WriteString(streamResp.Message.Content)
 		}
@@ -319,8 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 		}
 		if len(delta.Content) > 0 {
-			output <- api.Chunk {
+			output <- api.Chunk{
 				Content: delta.Content,
+				TokenCount: 1,
 			}
 			content.WriteString(delta.Content)
 		}
@@ -111,7 +111,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
 			m.replyCursor.Blink = false
 			cmds = append(cmds, m.replyCursor.BlinkCmd())
 
-			m.tokenCount++
+			m.tokenCount += msg.TokenCount
 			m.elapsed = time.Now().Sub(m.startTime)
 		case msgResponse:
 			cmds = append(cmds, m.waitForResponse()) // wait for the next response
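The two fields updated in this last hunk are all a tokens-per-second readout needs. A sketch of that computation, assuming the view divides the running count by the elapsed duration (the display code is not part of this diff):

```go
package tokencount

import "time"

// tokensPerSecond derives a tok/s figure from the counters the
// Update handler maintains (m.tokenCount and m.elapsed above).
// Guards against a zero elapsed duration.
func tokensPerSecond(tokens uint, elapsed time.Duration) float64 {
	if elapsed <= 0 {
		return 0
	}
	return float64(tokens) / elapsed.Seconds()
}
```

Before this commit the meter counted chunks rather than tokens, which is the broken Gemini tok/s figure the commit message calls out.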