// Package google implements an lmcli API client for the Gemini
// generateContent and streamGenerateContent HTTP endpoints.
package google

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/api"
)

// Client holds the credentials and base URL used to build request URLs.
type Client struct {
	APIKey  string
	BaseURL string
}

// ContentPart is one part of a Content message: plain text, a function call
// issued by the model, or a function response sent back to it.
type ContentPart struct {
	Text         string            `json:"text,omitempty"`
	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
	FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
}

// FunctionCall names a function and carries its arguments as strings.
type FunctionCall struct {
	Name string            `json:"name"`
	Args map[string]string `json:"args"`
}

// FunctionResponse carries the decoded JSON result of a named function.
type FunctionResponse struct {
	Name     string      `json:"name"`
	Response interface{} `json:"response"`
}

// Content is a single message in the conversation, attributed to Role.
type Content struct {
	Role  string        `json:"role"`
	Parts []ContentPart `json:"parts"`
}

// GenerationConfig holds optional sampling and length settings. Nil fields
// are omitted from the encoded request.
type GenerationConfig struct {
	MaxOutputTokens *int     `json:"maxOutputTokens,omitempty"`
	Temperature     *float32 `json:"temperature,omitempty"`
	TopP            *float32 `json:"topP,omitempty"`
	TopK            *int     `json:"topK,omitempty"`
}

// GenerateContentRequest is the JSON body sent to both the generateContent
// and streamGenerateContent endpoints.
type GenerateContentRequest struct {
	Contents          []Content         `json:"contents"`
	Tools             []Tool            `json:"tools,omitempty"`
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	GenerationConfig  *GenerationConfig `json:"generationConfig,omitempty"`
}

// Candidate is one generated alternative within a response.
type Candidate struct {
	Content      Content `json:"content"`
	FinishReason string  `json:"finishReason"`
	Index        int     `json:"index"`
}

// UsageMetadata reports the token counts attached to a response.
type UsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}

// GenerateContentResponse is the decoded response body. When streaming, one
// such value is decoded per "data: " line.
type GenerateContentResponse struct {
	Candidates    []Candidate   `json:"candidates"`
	UsageMetadata UsageMetadata `json:"usageMetadata"`
}

// Tool groups the function declarations offered to the model.
type Tool struct {
	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
}

// FunctionDeclaration describes one callable function and its parameters.
type FunctionDeclaration struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  ToolParameters `json:"parameters"`
}

// ToolParameters is the parameter schema of a function declaration.
type ToolParameters struct {
	Type       string                   `json:"type"`
	Properties map[string]ToolParameter `json:"properties,omitempty"`
	Required   []string                 `json:"required,omitempty"`
}

// ToolParameter describes a single named parameter of a function.
type ToolParameter struct {
	Type        string   `json:"type"`
	Description string   `json:"description"`
	Values      []string `json:"values,omitempty"`
}

// convertTools maps each api.ToolSpec to its own Tool holding a single
// function declaration whose parameter schema has type "OBJECT".
func convertTools(tools []api.ToolSpec) []Tool {
	geminiTools := make([]Tool, len(tools))
	for i, tool := range tools {
		params := make(map[string]ToolParameter, len(tool.Parameters))
		var required []string

		for _, param := range tool.Parameters {
			// TODO: proper enum handling
			params[param.Name] = ToolParameter{
				Type:        param.Type,
				Description: param.Description,
				Values:      param.Enum,
			}
			if param.Required {
				required = append(required, param.Name)
			}
		}

		geminiTools[i] = Tool{
			FunctionDeclarations: []FunctionDeclaration{
				{
					Name:        tool.Name,
					Description: tool.Description,
					Parameters: ToolParameters{
						Type:       "OBJECT",
						Properties: params,
						Required:   required,
					},
				},
			},
		}
	}
	return geminiTools
}

// convertToolCallToGemini turns each tool call into a ContentPart carrying a
// FunctionCall. Parameter values are stringified with %v because
// FunctionCall.Args is a map of strings.
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
	converted := make([]ContentPart, len(toolCalls))
	for i, call := range toolCalls {
		args := make(map[string]string, len(call.Parameters))
		for k, v := range call.Parameters {
			args[k] = fmt.Sprintf("%v", v)
		}
		converted[i].FunctionCall = &FunctionCall{
			Name: call.Name,
			Args: args,
		}
	}
	return converted
}

// convertToolCallToAPI turns function calls returned by the model into
// api.ToolCall values, copying each string argument into the parameter map.
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
	converted := make([]api.ToolCall, len(functionCalls))
	for i, call := range functionCalls {
		params := make(map[string]interface{}, len(call.Args))
		for k, v := range call.Args {
			params[k] = v
		}
		converted[i].Name = call.Name
		converted[i].Parameters = params
	}
	return converted
}

// convertToolResultsToGemini decodes each tool result's Result string as JSON
// and wraps it in a FunctionResponse. It returns an error if any result is
// not valid JSON.
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
	results := make([]FunctionResponse, len(toolResults))
	for i, result := range toolResults {
		var obj interface{}
		if err := json.Unmarshal([]byte(result.Result), &obj); err != nil {
			return nil, fmt.Errorf("Could not unmarshal %s: %w", result.Result, err)
		}
		results[i] = FunctionResponse{
			Name:     result.ToolName,
			Response: obj,
		}
	}
	return results, nil
}

// createGenerateContentRequest builds the request body for messages.
//
// A leading system message is lifted into SystemInstruction (only when its
// content is non-empty). "tool_call" messages become a single "model" message
// with one part per call, and each result of a "tool_result" message becomes
// its own "function" message. Assistant and user messages map to the "model"
// and "user" roles. Any other role yields an error.
func createGenerateContentRequest(
	params api.RequestParameters,
	messages []api.Message,
) (*GenerateContentRequest, error) {
	requestContents := make([]Content, 0, len(messages))

	startIdx := 0
	var system string
	if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
		system = messages[0].Content
		startIdx = 1
	}

	for _, m := range messages[startIdx:] {
		switch m.Role {
		case "tool_call":
			requestContents = append(requestContents, Content{
				Role:  "model",
				Parts: convertToolCallToGemini(m.ToolCalls),
			})
		case "tool_result":
			results, err := convertToolResultsToGemini(m.ToolResults)
			if err != nil {
				return nil, err
			}
			// Expand the message's results into one gemini message each.
			// Index into the slice rather than taking the address of the
			// range variable, which is shared across iterations before
			// Go 1.22 and would make every message point at the last result.
			for i := range results {
				requestContents = append(requestContents, Content{
					Role: "function",
					Parts: []ContentPart{
						{FunctionResp: &results[i]},
					},
				})
			}
		default:
			var role string
			switch m.Role {
			case api.MessageRoleAssistant:
				role = "model"
			case api.MessageRoleUser:
				role = "user"
			default:
				return nil, fmt.Errorf("unhandled role: %s", m.Role)
			}
			requestContents = append(requestContents, Content{
				Role: role,
				Parts: []ContentPart{
					{Text: m.Content},
				},
			})
		}
	}

	request := &GenerateContentRequest{
		Contents: requestContents,
		GenerationConfig: &GenerationConfig{
			// params is a by-value copy, so these pointers do not alias the
			// caller's struct.
			MaxOutputTokens: &params.MaxTokens,
			Temperature:     &params.Temperature,
			TopP:            &params.TopP,
		},
	}

	if system != "" {
		request.SystemInstruction = &Content{
			Parts: []ContentPart{
				{Text: system},
			},
		}
	}

	if len(params.Toolbox) > 0 {
		request.Tools = convertTools(params.Toolbox)
	}

	return request, nil
}

// sendRequest sets the JSON content type and performs req.
//
// On a 200 response the caller owns, and must close, the response body. On
// any other status the body is read and closed here and returned as the error
// text; the HTTP status line is used instead when the body is empty or cannot
// be read. The returned response is nil whenever the error is non-nil.
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != http.StatusOK {
		defer resp.Body.Close()
		body, readErr := io.ReadAll(resp.Body)
		if readErr != nil || len(body) == 0 {
			return nil, fmt.Errorf("%v", resp.Status)
		}
		return nil, fmt.Errorf("%v", string(body))
	}

	return resp, nil
}

// CreateChatCompletion sends messages to the model's generateContent endpoint
// and returns the first candidate as a message.
//
// If the last input message is an assistant message, its content is prepended
// to the returned content. The result has the tool-call role when the
// candidate contains function calls, and the assistant role otherwise. An
// error is returned when messages is empty, the request fails, or the
// response has no candidates.
func (c *Client) CreateChatCompletion(
	ctx context.Context,
	params api.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	req, err := createGenerateContentRequest(params, messages)
	if err != nil {
		return nil, err
	}

	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	url := fmt.Sprintf(
		"%s/v1beta/models/%s:generateContent?key=%s",
		c.BaseURL, params.Model, c.APIKey,
	)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}

	resp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var completionResp GenerateContentResponse
	if err := json.NewDecoder(resp.Body).Decode(&completionResp); err != nil {
		return nil, err
	}

	// Guard the index below: a response may carry no candidates at all.
	if len(completionResp.Candidates) == 0 {
		return nil, fmt.Errorf("response contained no candidates")
	}
	choice := completionResp.Candidates[0]

	var content strings.Builder
	lastMessage := messages[len(messages)-1]
	if lastMessage.Role.IsAssistant() {
		content.WriteString(lastMessage.Content)
	}

	var toolCalls []FunctionCall
	for _, part := range choice.Content.Parts {
		if part.Text != "" {
			content.WriteString(part.Text)
		}
		if part.FunctionCall != nil {
			toolCalls = append(toolCalls, *part.FunctionCall)
		}
	}

	if len(toolCalls) > 0 {
		return &api.Message{
			Role:      api.MessageRoleToolCall,
			Content:   content.String(),
			ToolCalls: convertToolCallToAPI(toolCalls),
		}, nil
	}

	return &api.Message{
		Role:    api.MessageRoleAssistant,
		Content: content.String(),
	}, nil
}

// CreateChatCompletionStream sends messages to the model's
// streamGenerateContent endpoint (alt=sse) and forwards each text part to
// output as it arrives, returning the accumulated message at end of stream.
//
// Only lines beginning with "data: " are decoded; all others are skipped. If
// the last input message is an assistant message, its content is prepended to
// the returned content but is not sent on output. The result has the
// tool-call role when any function call was received, and the assistant role
// otherwise. An error is returned when messages is empty, the request or a
// decode fails, ctx is done while sending on output, or the stream contained
// no candidates.
func (c *Client) CreateChatCompletionStream(
	ctx context.Context,
	params api.RequestParameters,
	messages []api.Message,
	output chan<- api.Chunk,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("Can't create completion from no messages")
	}

	req, err := createGenerateContentRequest(params, messages)
	if err != nil {
		return nil, err
	}

	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	url := fmt.Sprintf(
		"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
		c.BaseURL, params.Model, c.APIKey,
	)
	httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}

	resp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var content strings.Builder
	lastMessage := messages[len(messages)-1]
	if lastMessage.Role.IsAssistant() {
		content.WriteString(lastMessage.Content)
	}

	var toolCalls []FunctionCall
	dataPrefix := []byte("data: ")
	reader := bufio.NewReader(resp.Body)
	lastTokenCount := 0
	sawCandidate := false

	for done := false; !done; {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err != io.EOF {
				return nil, err
			}
			// ReadBytes returns any unterminated final line together with
			// io.EOF, so process it before leaving the loop.
			done = true
		}

		line = bytes.TrimSpace(line)
		if !bytes.HasPrefix(line, dataPrefix) {
			continue
		}

		var chunk GenerateContentResponse
		if err := json.Unmarshal(bytes.TrimPrefix(line, dataPrefix), &chunk); err != nil {
			return nil, err
		}

		// The reported candidate token count is a running total; turn it
		// into a per-chunk delta. Ignore totals that do not grow (e.g. a
		// chunk without usage metadata) so the uint conversion below cannot
		// wrap around.
		tokens := 0
		if n := chunk.UsageMetadata.CandidatesTokenCount; n > lastTokenCount {
			tokens = n - lastTokenCount
			lastTokenCount = n
		}

		if len(chunk.Candidates) == 0 {
			continue
		}
		sawCandidate = true

		for _, part := range chunk.Candidates[0].Content.Parts {
			if part.FunctionCall != nil {
				toolCalls = append(toolCalls, *part.FunctionCall)
				continue
			}
			if part.Text == "" {
				continue
			}

			select {
			case output <- api.Chunk{
				Content:    part.Text,
				TokenCount: uint(tokens),
			}:
			case <-ctx.Done():
				// Do not block forever if the consumer has gone away.
				return nil, ctx.Err()
			}
			// Attribute the chunk's token delta to its first text part only.
			tokens = 0
			content.WriteString(part.Text)
		}
	}

	if !sawCandidate {
		return nil, fmt.Errorf("response contained no candidates")
	}

	// Function calls are returned to the caller as a tool-call message.
	if len(toolCalls) > 0 {
		return &api.Message{
			Role:      api.MessageRoleToolCall,
			Content:   content.String(),
			ToolCalls: convertToolCallToAPI(toolCalls),
		}, nil
	}

	return &api.Message{
		Role:    api.MessageRoleAssistant,
		Content: content.String(),
	}, nil
}