Include token count in api.Chunk
And calculate the tokens/chunk for gemini responses, fixing the tok/s meter for gemini models. Further, only consider the first candidate of streamed gemini responses.
This commit is contained in:
@@ -244,6 +244,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
sb.WriteString(text)
|
||||
output <- api.Chunk{
|
||||
Content: text,
|
||||
TokenCount: 1,
|
||||
}
|
||||
case "content_block_stop":
|
||||
// ignore?
|
||||
@@ -266,6 +267,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||
output <- api.Chunk{
|
||||
Content: FUNCTION_STOP_SEQUENCE,
|
||||
TokenCount: 1,
|
||||
}
|
||||
|
||||
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||
|
||||
@@ -366,6 +366,8 @@ func (c *Client) CreateChatCompletionStream(
|
||||
var toolCalls []FunctionCall
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
lastTokenCount := 0
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
@@ -382,22 +384,25 @@ func (c *Client) CreateChatCompletionStream(
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamResp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
var resp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &resp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, candidate := range streamResp.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- api.Chunk {
|
||||
Content: part.Text,
|
||||
}
|
||||
content.WriteString(part.Text)
|
||||
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||
lastTokenCount += tokens
|
||||
|
||||
choice := resp.Candidates[0]
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- api.Chunk{
|
||||
Content: part.Text,
|
||||
TokenCount: uint(tokens),
|
||||
}
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +182,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
Content: streamResp.Message.Content,
|
||||
Content: streamResp.Message.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(streamResp.Message.Content)
|
||||
}
|
||||
|
||||
@@ -319,8 +319,9 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
}
|
||||
}
|
||||
if len(delta.Content) > 0 {
|
||||
output <- api.Chunk {
|
||||
Content: delta.Content,
|
||||
output <- api.Chunk{
|
||||
Content: delta.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user