Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between the `conversation` and `api` packages' representations of `Message`, the latter storing the minimum information required for interaction with LLM providers. There is necessary conversion between the two when making LLM calls.
437 lines
9.9 KiB
Go
437 lines
9.9 KiB
Go
package google
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
|
)
|
|
|
|
// Client calls the Google generateContent API. The zero value is not
// usable: both fields must be set before making requests.
type Client struct {
	APIKey  string // sent as the `key` query parameter on every request
	BaseURL string // URL prefix for the API, joined with /v1beta/models/...
}
|
|
|
|
// ContentPart is one piece of a Content turn. At most one of Text,
// FunctionCall, or FunctionResp is expected to be populated.
type ContentPart struct {
	Text         string            `json:"text,omitempty"`
	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
	FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}
|
|
|
|
// FunctionCall is a tool invocation requested by the model: the declared
// function name plus its arguments as a string-to-string map.
type FunctionCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}
|
|
|
|
// FunctionResponse carries the result of an executed tool back to the
// model. Response holds arbitrary JSON-decoded data.
type FunctionResponse struct {
	Name     string      `json:"name"`
	Response interface{} `json:"response"`
}
|
|
|
|
// Content is a single conversation turn, attributed to a role — this file
// produces "user", "model", and "function" — and made up of one or more
// parts.
type Content struct {
	Role  string        `json:"role"`
	Parts []ContentPart `json:"parts"`
}
|
|
|
|
// GenerationConfig holds optional sampling and output-limit settings for a
// generateContent request. Nil fields are omitted from the JSON payload,
// leaving the API's defaults in effect.
type GenerationConfig struct {
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}
|
|
|
|
// GenerateContentRequest is the JSON request body for the generateContent
// and streamGenerateContent endpoints.
type GenerateContentRequest struct {
	Contents          []Content         `json:"contents"`
	Tools             []Tool            `json:"tools,omitempty"`
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	GenerationConfig  *GenerationConfig `json:"generationConfig,omitempty"`
}
|
|
|
|
// Candidate is one generated completion within a response.
type Candidate struct {
	Content      Content `json:"content"`
	FinishReason string  `json:"finishReason"`
	Index        int     `json:"index"`
}
|
|
|
|
// UsageMetadata reports token usage for a generateContent call. In
// streaming responses the counts are cumulative across chunks (see the
// delta calculation in CreateChatCompletionStream).
type UsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}
|
|
|
|
// GenerateContentResponse is the JSON response body of generateContent,
// and also the shape of each SSE chunk from streamGenerateContent.
type GenerateContentResponse struct {
	Candidates    []Candidate   `json:"candidates"`
	UsageMetadata UsageMetadata `json:"usageMetadata"`
}
|
|
|
|
// Tool advertises a set of callable functions to the model.
type Tool struct {
	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}
|
|
|
|
// FunctionDeclaration describes a single callable function: its name,
// human-readable description, and parameter schema.
type FunctionDeclaration struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
}
|
|
|
|
// ToolParameters is an object schema for a function's parameters;
// convertTools always sets Type to "OBJECT".
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}
|
|
|
|
// ToolParameter describes one function parameter.
type ToolParameter struct {
	Type        string `json:"type"`
	Description string `json:"description"`
	// Values holds the parameter's candidate values, populated from the
	// spec's Enum. NOTE(review): Gemini's schema names this field "enum" —
	// confirm the "values" tag is intentional (see the TODO in convertTools).
	Values []string `json:"values,omitempty"`
}
|
|
|
|
func convertTools(tools []api.ToolSpec) []Tool {
|
|
geminiTools := make([]Tool, len(tools))
|
|
for i, tool := range tools {
|
|
params := make(map[string]ToolParameter)
|
|
var required []string
|
|
|
|
for _, param := range tool.Parameters {
|
|
// TODO: proper enum handing
|
|
params[param.Name] = ToolParameter{
|
|
Type: param.Type,
|
|
Description: param.Description,
|
|
Values: param.Enum,
|
|
}
|
|
if param.Required {
|
|
required = append(required, param.Name)
|
|
}
|
|
}
|
|
|
|
geminiTools[i] = Tool{
|
|
FunctionDeclarations: []FunctionDeclaration{
|
|
{
|
|
Name: tool.Name,
|
|
Description: tool.Description,
|
|
Parameters: ToolParameters{
|
|
Type: "OBJECT",
|
|
Properties: params,
|
|
Required: required,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
return geminiTools
|
|
}
|
|
|
|
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
|
converted := make([]ContentPart, len(toolCalls))
|
|
for i, call := range toolCalls {
|
|
args := make(map[string]string)
|
|
for k, v := range call.Parameters {
|
|
args[k] = fmt.Sprintf("%v", v)
|
|
}
|
|
converted[i].FunctionCall = &FunctionCall{
|
|
Name: call.Name,
|
|
Args: args,
|
|
}
|
|
}
|
|
return converted
|
|
}
|
|
|
|
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
|
converted := make([]api.ToolCall, len(functionCalls))
|
|
for i, call := range functionCalls {
|
|
params := make(map[string]interface{})
|
|
for k, v := range call.Args {
|
|
params[k] = v
|
|
}
|
|
converted[i].Name = call.Name
|
|
converted[i].Parameters = params
|
|
}
|
|
return converted
|
|
}
|
|
|
|
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
|
results := make([]FunctionResponse, len(toolResults))
|
|
for i, result := range toolResults {
|
|
var obj interface{}
|
|
err := json.Unmarshal([]byte(result.Result), &obj)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
|
}
|
|
results[i] = FunctionResponse{
|
|
Name: result.ToolName,
|
|
Response: obj,
|
|
}
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func createGenerateContentRequest(
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (*GenerateContentRequest, error) {
|
|
requestContents := make([]Content, 0, len(messages))
|
|
|
|
startIdx := 0
|
|
var system string
|
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
|
system = messages[0].Content
|
|
startIdx = 1
|
|
}
|
|
|
|
for _, m := range messages[startIdx:] {
|
|
switch m.Role {
|
|
case "tool_call":
|
|
content := Content{
|
|
Role: "model",
|
|
Parts: convertToolCallToGemini(m.ToolCalls),
|
|
}
|
|
requestContents = append(requestContents, content)
|
|
case "tool_result":
|
|
results, err := convertToolResultsToGemini(m.ToolResults)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// expand tool_result messages' results into multiple gemini messages
|
|
for _, result := range results {
|
|
content := Content{
|
|
Role: "function",
|
|
Parts: []ContentPart{
|
|
{
|
|
FunctionResp: &result,
|
|
},
|
|
},
|
|
}
|
|
requestContents = append(requestContents, content)
|
|
}
|
|
default:
|
|
var role string
|
|
switch m.Role {
|
|
case api.MessageRoleAssistant:
|
|
role = "model"
|
|
case api.MessageRoleUser:
|
|
role = "user"
|
|
}
|
|
|
|
if role == "" {
|
|
panic("Unhandled role: " + m.Role)
|
|
}
|
|
|
|
content := Content{
|
|
Role: role,
|
|
Parts: []ContentPart{
|
|
{
|
|
Text: m.Content,
|
|
},
|
|
},
|
|
}
|
|
requestContents = append(requestContents, content)
|
|
}
|
|
}
|
|
|
|
request := &GenerateContentRequest{
|
|
Contents: requestContents,
|
|
GenerationConfig: &GenerationConfig{
|
|
MaxOutputTokens: ¶ms.MaxTokens,
|
|
Temperature: ¶ms.Temperature,
|
|
TopP: ¶ms.TopP,
|
|
},
|
|
}
|
|
|
|
if system != "" {
|
|
request.SystemInstruction = &Content{
|
|
Parts: []ContentPart{
|
|
{
|
|
Text: system,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
if len(params.Toolbox) > 0 {
|
|
request.Tools = convertTools(params.Toolbox)
|
|
}
|
|
|
|
return request, nil
|
|
}
|
|
|
|
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
func (c *Client) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
|
}
|
|
|
|
req, err := createGenerateContentRequest(params, messages)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
url := fmt.Sprintf(
|
|
"%s/v1beta/models/%s:generateContent?key=%s",
|
|
c.BaseURL, params.Model, c.APIKey,
|
|
)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := c.sendRequest(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var completionResp GenerateContentResponse
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
choice := completionResp.Candidates[0]
|
|
|
|
var content string
|
|
lastMessage := messages[len(messages)-1]
|
|
if lastMessage.Role.IsAssistant() {
|
|
content = lastMessage.Content
|
|
}
|
|
|
|
var toolCalls []FunctionCall
|
|
for _, part := range choice.Content.Parts {
|
|
if part.Text != "" {
|
|
content += part.Text
|
|
}
|
|
|
|
if part.FunctionCall != nil {
|
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
|
}
|
|
}
|
|
|
|
if len(toolCalls) > 0 {
|
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
|
}
|
|
|
|
return api.NewMessageWithAssistant(content), nil
|
|
}
|
|
|
|
func (c *Client) CreateChatCompletionStream(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
output chan<- provider.Chunk,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
|
}
|
|
|
|
req, err := createGenerateContentRequest(params, messages)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
url := fmt.Sprintf(
|
|
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
|
c.BaseURL, params.Model, c.APIKey,
|
|
)
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := c.sendRequest(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
content := strings.Builder{}
|
|
|
|
lastMessage := messages[len(messages)-1]
|
|
if lastMessage.Role.IsAssistant() {
|
|
content.WriteString(lastMessage.Content)
|
|
}
|
|
|
|
var toolCalls []FunctionCall
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
|
|
lastTokenCount := 0
|
|
for {
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
|
continue
|
|
}
|
|
|
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
|
|
|
var resp GenerateContentResponse
|
|
err = json.Unmarshal(line, &resp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
|
lastTokenCount += tokens
|
|
|
|
choice := resp.Candidates[0]
|
|
for _, part := range choice.Content.Parts {
|
|
if part.FunctionCall != nil {
|
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
|
} else if part.Text != "" {
|
|
output <- provider.Chunk{
|
|
Content: part.Text,
|
|
TokenCount: uint(tokens),
|
|
}
|
|
content.WriteString(part.Text)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(toolCalls) > 0 {
|
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
|
}
|
|
|
|
return api.NewMessageWithAssistant(content.String()), nil
|
|
}
|