Compare commits
No commits in common. "a1fc8a637b345969fb304d24d18044587ff1867d" and "c50b6b154d0af008db8abf1325b836548fc5c6ee" have entirely different histories.
a1fc8a637b
...
c50b6b154d
2
TODO.md
2
TODO.md
@ -3,7 +3,7 @@
|
|||||||
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
||||||
when calling anthropic?
|
when calling anthropic?
|
||||||
- [x] `dir_tree` tool
|
- [x] `dir_tree` tool
|
||||||
- [x] Implement native Anthropic API tool calling
|
- [ ] Implement native Anthropic API tool calling
|
||||||
- [ ] Agents - a name given to a system prompt + set of available tools +
|
- [ ] Agents - a name given to a system prompt + set of available tools +
|
||||||
potentially other relevent data (e.g. external service credentials, files for
|
potentially other relevent data (e.g. external service credentials, files for
|
||||||
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
||||||
|
@ -40,10 +40,3 @@ type ChatCompletionProvider interface {
|
|||||||
chunks chan<- Chunk,
|
chunks chan<- Chunk,
|
||||||
) (*Message, error)
|
) (*Message, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsAssistantContinuation(messages []Message) bool {
|
|
||||||
if len(messages) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return messages[len(messages)-1].Role == MessageRoleAssistant
|
|
||||||
}
|
|
||||||
|
@ -5,223 +5,100 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/xml"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ANTHROPIC_VERSION = "2023-06-01"
|
func buildRequest(params api.RequestParameters, messages []api.Message) Request {
|
||||||
|
requestBody := Request{
|
||||||
type AnthropicClient struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content interface{} `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
InputSchema InputSchema `json:"input_schema"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type InputSchema struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]Property `json:"properties"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Property struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []ChatCompletionMessage `json:"messages"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
Stream bool `json:"stream"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContentBlock struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
ID string `json:"id,omitempty"`
|
|
||||||
Name string `json:"name,omitempty"`
|
|
||||||
Input interface{} `json:"input,omitempty"`
|
|
||||||
partialJsonAccumulator string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionResponse struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
Model string `json:"model"`
|
|
||||||
Content []ContentBlock `json:"content"`
|
|
||||||
StopReason string `json:"stop_reason"`
|
|
||||||
Usage Usage `json:"usage"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Usage struct {
|
|
||||||
InputTokens int `json:"input_tokens"`
|
|
||||||
OutputTokens int `json:"output_tokens"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type StreamEvent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Message interface{} `json:"message,omitempty"`
|
|
||||||
Index int `json:"index,omitempty"`
|
|
||||||
Delta interface{} `json:"delta,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertTools(tools []api.ToolSpec) []Tool {
|
|
||||||
anthropicTools := make([]Tool, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
properties := make(map[string]Property)
|
|
||||||
for _, param := range tool.Parameters {
|
|
||||||
properties[param.Name] = Property{
|
|
||||||
Type: param.Type,
|
|
||||||
Description: param.Description,
|
|
||||||
Enum: param.Enum,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var required []string
|
|
||||||
for _, param := range tool.Parameters {
|
|
||||||
if param.Required {
|
|
||||||
required = append(required, param.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
anthropicTools[i] = Tool{
|
|
||||||
Name: tool.Name,
|
|
||||||
Description: tool.Description,
|
|
||||||
InputSchema: InputSchema{
|
|
||||||
Type: "object",
|
|
||||||
Properties: properties,
|
|
||||||
Required: required,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return anthropicTools
|
|
||||||
}
|
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
|
||||||
params api.RequestParameters,
|
|
||||||
messages []api.Message,
|
|
||||||
) (string, ChatCompletionRequest) {
|
|
||||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
|
||||||
var systemMessage string
|
|
||||||
|
|
||||||
for _, m := range messages {
|
|
||||||
if m.Role == api.MessageRoleSystem {
|
|
||||||
systemMessage = m.Content
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var content interface{}
|
|
||||||
role := string(m.Role)
|
|
||||||
|
|
||||||
switch m.Role {
|
|
||||||
case api.MessageRoleToolCall:
|
|
||||||
role = "assistant"
|
|
||||||
contentBlocks := make([]map[string]interface{}, 0)
|
|
||||||
if m.Content != "" {
|
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
||||||
"type": "text",
|
|
||||||
"text": m.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
for _, toolCall := range m.ToolCalls {
|
|
||||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
|
||||||
"type": "tool_use",
|
|
||||||
"id": toolCall.ID,
|
|
||||||
"name": toolCall.Name,
|
|
||||||
"input": toolCall.Parameters,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
content = contentBlocks
|
|
||||||
|
|
||||||
case api.MessageRoleToolResult:
|
|
||||||
role = "user"
|
|
||||||
contentBlocks := make([]map[string]interface{}, 0)
|
|
||||||
for _, result := range m.ToolResults {
|
|
||||||
contentBlock := map[string]interface{}{
|
|
||||||
"type": "tool_result",
|
|
||||||
"tool_use_id": result.ToolCallID,
|
|
||||||
"content": result.Result,
|
|
||||||
}
|
|
||||||
contentBlocks = append(contentBlocks, contentBlock)
|
|
||||||
}
|
|
||||||
content = contentBlocks
|
|
||||||
|
|
||||||
default:
|
|
||||||
content = m.Content
|
|
||||||
}
|
|
||||||
|
|
||||||
requestMessages = append(requestMessages, ChatCompletionMessage{
|
|
||||||
Role: role,
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
request := ChatCompletionRequest{
|
|
||||||
Model: params.Model,
|
Model: params.Model,
|
||||||
Messages: requestMessages,
|
Messages: make([]Message, len(messages)),
|
||||||
System: systemMessage,
|
|
||||||
MaxTokens: params.MaxTokens,
|
MaxTokens: params.MaxTokens,
|
||||||
Temperature: params.Temperature,
|
Temperature: params.Temperature,
|
||||||
|
Stream: false,
|
||||||
|
|
||||||
|
StopSequences: []string{
|
||||||
|
FUNCTION_STOP_SEQUENCE,
|
||||||
|
"\n\nHuman:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
startIdx := 0
|
||||||
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||||
|
requestBody.System = messages[0].Content
|
||||||
|
requestBody.Messages = requestBody.Messages[1:]
|
||||||
|
startIdx = 1
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(params.ToolBag) > 0 {
|
if len(params.ToolBag) > 0 {
|
||||||
request.Tools = convertTools(params.ToolBag)
|
if len(requestBody.System) > 0 {
|
||||||
|
// add a divider between existing system prompt and tools
|
||||||
|
requestBody.System += "\n\n---\n\n"
|
||||||
|
}
|
||||||
|
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefill string
|
for i, msg := range messages[startIdx:] {
|
||||||
if api.IsAssistantContinuation(messages) {
|
message := &requestBody.Messages[i]
|
||||||
prefill = messages[len(messages)-1].Content
|
|
||||||
}
|
|
||||||
|
|
||||||
return prefill, request
|
switch msg.Role {
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
message.Role = "assistant"
|
||||||
|
if msg.Content != "" {
|
||||||
|
message.Content = msg.Content
|
||||||
|
}
|
||||||
|
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
||||||
|
xmlString, err := xmlFuncCalls.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
||||||
|
}
|
||||||
|
if len(message.Content) > 0 {
|
||||||
|
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
||||||
|
} else {
|
||||||
|
message.Content = xmlString
|
||||||
|
}
|
||||||
|
case api.MessageRoleToolResult:
|
||||||
|
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||||
|
xmlString, err := xmlFuncResults.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||||
|
}
|
||||||
|
message.Role = "user"
|
||||||
|
message.Content = xmlString
|
||||||
|
default:
|
||||||
|
message.Role = string(msg.Role)
|
||||||
|
message.Content = msg.Content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return requestBody
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||||
jsonData, err := json.Marshal(r)
|
jsonBody, err := json.Marshal(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/messages", bytes.NewBuffer(jsonBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("x-api-key", c.APIKey)
|
req.Header.Set("x-api-key", c.APIKey)
|
||||||
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
req.Header.Set("anthropic-version", "2023-06-01")
|
||||||
req.Header.Set("content-type", "application/json")
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
return resp, nil
|
||||||
bytes, _ := io.ReadAll(resp.Body)
|
|
||||||
return resp, fmt.Errorf("%v", string(bytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletion(
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
@ -230,25 +107,45 @@ func (c *AnthropicClient) CreateChatCompletion(
|
|||||||
messages []api.Message,
|
messages []api.Message,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, req := createChatCompletionRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
req.Stream = false
|
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, req)
|
resp, err := sendRequest(ctx, c, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
var completionResp ChatCompletionResponse
|
var response Response
|
||||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
return nil, fmt.Errorf("failed to decode response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return convertResponseToMessage(completionResp)
|
sb := strings.Builder{}
|
||||||
|
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
// this is a continuation of a previous assistant reply, so we'll
|
||||||
|
// include its contents in the final result
|
||||||
|
sb.WriteString(lastMessage.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, content := range response.Content {
|
||||||
|
switch content.Type {
|
||||||
|
case "text":
|
||||||
|
sb.WriteString(content.Text)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported message type: %s", content.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: sb.String(),
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
@ -258,193 +155,144 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
|||||||
output chan<- api.Chunk,
|
output chan<- api.Chunk,
|
||||||
) (*api.Message, error) {
|
) (*api.Message, error) {
|
||||||
if len(messages) == 0 {
|
if len(messages) == 0 {
|
||||||
return nil, fmt.Errorf("can't create completion from no messages")
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
}
|
}
|
||||||
|
|
||||||
prefill, req := createChatCompletionRequest(params, messages)
|
request := buildRequest(params, messages)
|
||||||
req.Stream = true
|
request.Stream = true
|
||||||
|
|
||||||
resp, err := c.sendRequest(ctx, req)
|
resp, err := sendRequest(ctx, c, request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
contentBlocks := make(map[int]*ContentBlock)
|
sb := strings.Builder{}
|
||||||
var finalMessage *ChatCompletionResponse
|
|
||||||
|
|
||||||
var firstChunkReceived bool
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if messages[len(messages)-1].Role.IsAssistant() {
|
||||||
reader := bufio.NewReader(resp.Body)
|
// this is a continuation of a previous assistant reply, so we'll
|
||||||
for {
|
// include its contents in the final result
|
||||||
select {
|
// TODO: handle this at higher level
|
||||||
case <-ctx.Done():
|
sb.WriteString(lastMessage.Content)
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
line, err := reader.ReadBytes('\n')
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimSpace(line)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
for scanner.Scan() {
|
||||||
|
line := scanner.Text()
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
if line[0] == '{' {
|
||||||
|
var event map[string]interface{}
|
||||||
var streamEvent StreamEvent
|
err := json.Unmarshal([]byte(line), &event)
|
||||||
err = json.Unmarshal(line, &streamEvent)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||||
|
}
|
||||||
|
eventType, ok := event["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid event: %s", line)
|
||||||
|
}
|
||||||
|
switch eventType {
|
||||||
|
case "error":
|
||||||
|
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown event type: %s", eventType)
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(line, "data:") {
|
||||||
|
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||||
|
var event map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(data), &event)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch streamEvent.Type {
|
eventType, ok := event["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid event type")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch eventType {
|
||||||
case "message_start":
|
case "message_start":
|
||||||
finalMessage = &ChatCompletionResponse{}
|
// noop
|
||||||
err = json.Unmarshal(line, &struct {
|
case "ping":
|
||||||
Message *ChatCompletionResponse `json:"message"`
|
// signals start of text - currently ignoring
|
||||||
}{Message: finalMessage})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
|
||||||
}
|
|
||||||
case "content_block_start":
|
case "content_block_start":
|
||||||
var contentBlockStart struct {
|
// ignore?
|
||||||
Index int `json:"index"`
|
|
||||||
ContentBlock ContentBlock `json:"content_block"`
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(line, &contentBlockStart)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
|
||||||
case "content_block_delta":
|
case "content_block_delta":
|
||||||
if streamEvent.Index >= len(contentBlocks) {
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
block := contentBlocks[streamEvent.Index]
|
|
||||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
return nil, fmt.Errorf("invalid content block delta")
|
||||||
}
|
}
|
||||||
|
text, ok := delta["text"].(string)
|
||||||
deltaType, ok := delta["type"].(string)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("delta missing type field")
|
return nil, fmt.Errorf("invalid text delta")
|
||||||
}
|
}
|
||||||
|
sb.WriteString(text)
|
||||||
switch deltaType {
|
|
||||||
case "text_delta":
|
|
||||||
if text, ok := delta["text"].(string); ok {
|
|
||||||
if !firstChunkReceived {
|
|
||||||
if prefill == "" {
|
|
||||||
// if there is no prefil, ensure we trim leading whitespace
|
|
||||||
text = strings.TrimSpace(text)
|
|
||||||
}
|
|
||||||
firstChunkReceived = true
|
|
||||||
}
|
|
||||||
block.Text += text
|
|
||||||
output <- api.Chunk{
|
output <- api.Chunk{
|
||||||
Content: text,
|
Content: text,
|
||||||
TokenCount: 1,
|
TokenCount: 1,
|
||||||
}
|
}
|
||||||
}
|
|
||||||
case "input_json_delta":
|
|
||||||
if block.Type != "tool_use" {
|
|
||||||
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
|
||||||
}
|
|
||||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
|
||||||
block.partialJsonAccumulator += partialJSON
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case "content_block_stop":
|
case "content_block_stop":
|
||||||
if streamEvent.Index >= len(contentBlocks) {
|
// ignore?
|
||||||
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
block := contentBlocks[streamEvent.Index]
|
|
||||||
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
|
||||||
var inputData map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
|
||||||
}
|
|
||||||
block.Input = inputData
|
|
||||||
}
|
|
||||||
case "message_delta":
|
case "message_delta":
|
||||||
if finalMessage == nil {
|
delta, ok := event["delta"].(map[string]interface{})
|
||||||
return nil, fmt.Errorf("received message_delta before message_start")
|
|
||||||
}
|
|
||||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
return nil, fmt.Errorf("invalid message delta")
|
||||||
}
|
}
|
||||||
if stopReason, ok := delta["stop_reason"].(string); ok {
|
stopReason, ok := delta["stop_reason"].(string)
|
||||||
finalMessage.StopReason = stopReason
|
if ok && stopReason == "stop_sequence" {
|
||||||
|
stopSequence, ok := delta["stop_sequence"].(string)
|
||||||
|
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
||||||
|
content := sb.String()
|
||||||
|
|
||||||
|
start := strings.Index(content, "<function_calls>")
|
||||||
|
if start == -1 {
|
||||||
|
return nil, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||||
|
output <- api.Chunk{
|
||||||
|
Content: FUNCTION_STOP_SEQUENCE,
|
||||||
|
TokenCount: 1,
|
||||||
|
}
|
||||||
|
funcCallXml := content[start:] + FUNCTION_STOP_SEQUENCE
|
||||||
|
|
||||||
|
var functionCalls XMLFunctionCalls
|
||||||
|
err := xml.Unmarshal([]byte(funcCallXml), &functionCalls)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.Message{
|
||||||
|
Role: api.MessageRoleToolCall,
|
||||||
|
// function call xml stripped from content for model interop
|
||||||
|
Content: strings.TrimSpace(content[:start]),
|
||||||
|
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
case "message_stop":
|
case "message_stop":
|
||||||
// End of the stream
|
// return the completed message
|
||||||
goto END_STREAM
|
content := sb.String()
|
||||||
|
return &api.Message{
|
||||||
case "error":
|
|
||||||
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
|
||||||
|
|
||||||
default:
|
|
||||||
// Ignore unknown event types
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
END_STREAM:
|
|
||||||
if finalMessage == nil {
|
|
||||||
return nil, fmt.Errorf("no final message received")
|
|
||||||
}
|
|
||||||
|
|
||||||
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
|
||||||
for _, v := range contentBlocks {
|
|
||||||
finalMessage.Content = append(finalMessage.Content, *v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return convertResponseToMessage(*finalMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
|
||||||
content := strings.Builder{}
|
|
||||||
var toolCalls []api.ToolCall
|
|
||||||
|
|
||||||
for _, block := range resp.Content {
|
|
||||||
switch block.Type {
|
|
||||||
case "text":
|
|
||||||
content.WriteString(block.Text)
|
|
||||||
case "tool_use":
|
|
||||||
parameters, ok := block.Input.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, api.ToolCall{
|
|
||||||
ID: block.ID,
|
|
||||||
Name: block.Name,
|
|
||||||
Parameters: parameters,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
message := &api.Message{
|
|
||||||
Role: api.MessageRoleAssistant,
|
Role: api.MessageRoleAssistant,
|
||||||
Content: content.String(),
|
Content: content,
|
||||||
ToolCalls: toolCalls,
|
}, nil
|
||||||
|
case "error":
|
||||||
|
return nil, fmt.Errorf("an error occurred: %s", event["error"])
|
||||||
|
default:
|
||||||
|
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
if err := scanner.Err(); err != nil {
|
||||||
message.Role = api.MessageRoleToolCall
|
return nil, fmt.Errorf("failed to read response body: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return message, nil
|
return nil, fmt.Errorf("unexpected end of stream")
|
||||||
}
|
}
|
||||||
|
232
pkg/api/provider/anthropic/tools.go
Normal file
232
pkg/api/provider/anthropic/tools.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"text/template"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||||
|
|
||||||
|
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
||||||
|
|
||||||
|
You may call them like this:
|
||||||
|
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:`
|
||||||
|
|
||||||
|
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
|
||||||
|
|
||||||
|
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
|
||||||
|
|
||||||
|
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
|
||||||
|
|
||||||
|
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
||||||
|
|
||||||
|
type XMLTools struct {
|
||||||
|
XMLName struct{} `xml:"tools"`
|
||||||
|
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLToolDescription struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Description string `xml:"description"`
|
||||||
|
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLToolParameter struct {
|
||||||
|
Name string `xml:"name"`
|
||||||
|
Type string `xml:"type"`
|
||||||
|
Description string `xml:"description"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionCalls struct {
|
||||||
|
XMLName struct{} `xml:"function_calls"`
|
||||||
|
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionInvoke struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionInvokeParameters struct {
|
||||||
|
String string `xml:",innerxml"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionResults struct {
|
||||||
|
XMLName struct{} `xml:"function_results"`
|
||||||
|
Result []XMLFunctionResult `xml:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type XMLFunctionResult struct {
|
||||||
|
ToolName string `xml:"tool_name"`
|
||||||
|
Stdout string `xml:"stdout"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
||||||
|
// parameters name to value
|
||||||
|
func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||||
|
lines := strings.Split(params, "\n")
|
||||||
|
ret := make(map[string]interface{}, len(lines))
|
||||||
|
for _, line := range lines {
|
||||||
|
i := strings.Index(line, ">")
|
||||||
|
if i == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
j := strings.Index(line, "</")
|
||||||
|
if j == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// chop from after opening < to first > to get parameter name,
|
||||||
|
// then chop after > to first </ to get parameter value
|
||||||
|
ret[line[1:i]] = line[i+1 : j]
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolsToXMLTools(tools []api.ToolSpec) XMLTools {
|
||||||
|
converted := make([]XMLToolDescription, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
converted[i].ToolName = tool.Name
|
||||||
|
converted[i].Description = tool.Description
|
||||||
|
|
||||||
|
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||||
|
for j, param := range tool.Parameters {
|
||||||
|
params[j].Name = param.Name
|
||||||
|
params[j].Description = param.Description
|
||||||
|
params[j].Type = param.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
converted[i].Parameters = params
|
||||||
|
}
|
||||||
|
return XMLTools{
|
||||||
|
ToolDescriptions: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []api.ToolCall {
|
||||||
|
toolCalls := make([]api.ToolCall, len(functionCalls.Invoke))
|
||||||
|
for i, invoke := range functionCalls.Invoke {
|
||||||
|
toolCalls[i].Name = invoke.ToolName
|
||||||
|
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||||
|
}
|
||||||
|
return toolCalls
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallsToXMLFunctionCalls(toolCalls []api.ToolCall) XMLFunctionCalls {
|
||||||
|
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||||
|
for i, toolCall := range toolCalls {
|
||||||
|
var params XMLFunctionInvokeParameters
|
||||||
|
var paramXML string
|
||||||
|
for key, value := range toolCall.Parameters {
|
||||||
|
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||||
|
}
|
||||||
|
params.String = paramXML
|
||||||
|
converted[i] = XMLFunctionInvoke{
|
||||||
|
ToolName: toolCall.Name,
|
||||||
|
Parameters: params,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return XMLFunctionCalls{
|
||||||
|
Invoke: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolResultsToXMLFunctionResult(toolResults []api.ToolResult) XMLFunctionResults {
|
||||||
|
converted := make([]XMLFunctionResult, len(toolResults))
|
||||||
|
for i, result := range toolResults {
|
||||||
|
converted[i].ToolName = result.ToolName
|
||||||
|
converted[i].Stdout = result.Result
|
||||||
|
}
|
||||||
|
return XMLFunctionResults{
|
||||||
|
Result: converted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildToolsSystemPrompt(tools []api.ToolSpec) string {
|
||||||
|
xmlTools := convertToolsToXMLTools(tools)
|
||||||
|
xmlToolsString, err := xmlTools.XMLString()
|
||||||
|
if err != nil {
|
||||||
|
panic("Could not serialize []api.Tool to XMLTools")
|
||||||
|
}
|
||||||
|
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x XMLTools) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("tools").Parse(`<tools>
|
||||||
|
{{range .ToolDescriptions}}<tool_description>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<description>
|
||||||
|
{{.Description}}
|
||||||
|
</description>
|
||||||
|
<parameters>
|
||||||
|
{{range .Parameters}}<parameter>
|
||||||
|
<name>{{.Name}}</name>
|
||||||
|
<type>{{.Type}}</type>
|
||||||
|
<description>{{.Description}}</description>
|
||||||
|
</parameter>
|
||||||
|
{{end}}</parameters>
|
||||||
|
</tool_description>
|
||||||
|
{{end}}</tools>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||||
|
{{range .Result}}<result>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<stdout>{{.Stdout}}</stdout>
|
||||||
|
</result>
|
||||||
|
{{end}}</function_results>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||||
|
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||||
|
{{range .Invoke}}<invoke>
|
||||||
|
<tool_name>{{.ToolName}}</tool_name>
|
||||||
|
<parameters>{{.Parameters.String}}</parameters>
|
||||||
|
</invoke>
|
||||||
|
{{end}}</function_calls>`)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := tmpl.Execute(&buf, x); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String(), nil
|
||||||
|
}
|
38
pkg/api/provider/anthropic/types.go
Normal file
38
pkg/api/provider/anthropic/types.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
type AnthropicClient struct {
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Request struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []Message `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
//TopP float32 `json:"top_p,omitempty"`
|
||||||
|
//TopK float32 `json:"top_k,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OriginalContent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Response struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []OriginalContent `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
StopSequence string `json:"stop_sequence"`
|
||||||
|
}
|
||||||
|
|
@ -13,85 +13,6 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ContentPart struct {
|
|
||||||
Text string `json:"text,omitempty"`
|
|
||||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
|
||||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionCall struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Args map[string]string `json:"args"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionResponse struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Response interface{} `json:"response"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Content struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Parts []ContentPart `json:"parts"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerationConfig struct {
|
|
||||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
|
||||||
Temperature *float32 `json:"temperature,omitempty"`
|
|
||||||
TopP *float32 `json:"topP,omitempty"`
|
|
||||||
TopK *int `json:"topK,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerateContentRequest struct {
|
|
||||||
Contents []Content `json:"contents"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
|
||||||
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Candidate struct {
|
|
||||||
Content Content `json:"content"`
|
|
||||||
FinishReason string `json:"finishReason"`
|
|
||||||
Index int `json:"index"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UsageMetadata struct {
|
|
||||||
PromptTokenCount int `json:"promptTokenCount"`
|
|
||||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
|
||||||
TotalTokenCount int `json:"totalTokenCount"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GenerateContentResponse struct {
|
|
||||||
Candidates []Candidate `json:"candidates"`
|
|
||||||
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionDeclaration struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters ToolParameters `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Values []string `json:"values,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertTools(tools []api.ToolSpec) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
geminiTools := make([]Tool, len(tools))
|
geminiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
|
80
pkg/api/provider/google/types.go
Normal file
80
pkg/api/provider/google/types.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package google
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentPart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
|
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Args map[string]string `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Response interface{} `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Content struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Parts []ContentPart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerationConfig struct {
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentRequest struct {
|
||||||
|
Contents []Content `json:"contents"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||||
|
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Candidate struct {
|
||||||
|
Content Content `json:"content"`
|
||||||
|
FinishReason string `json:"finishReason"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentResponse struct {
|
||||||
|
Candidates []Candidate `json:"candidates"`
|
||||||
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDeclaration struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Values []string `json:"values,omitempty"`
|
||||||
|
}
|
@ -13,76 +13,6 @@ import (
|
|||||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIClient struct {
|
|
||||||
APIKey string
|
|
||||||
BaseURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionMessage struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content,omitempty"`
|
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Index *int `json:"index,omitempty"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FunctionDefinition struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Parameters ToolParameters `json:"parameters"`
|
|
||||||
Arguments string `json:"arguments,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Function FunctionDefinition `json:"function"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionRequest struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
Messages []ChatCompletionMessage `json:"messages"`
|
|
||||||
N int `json:"n"`
|
|
||||||
Tools []Tool `json:"tools,omitempty"`
|
|
||||||
ToolChoice string `json:"tool_choice,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionChoice struct {
|
|
||||||
Message ChatCompletionMessage `json:"message"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionResponse struct {
|
|
||||||
Choices []ChatCompletionChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamChoice struct {
|
|
||||||
Delta ChatCompletionMessage `json:"delta"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatCompletionStreamResponse struct {
|
|
||||||
Choices []ChatCompletionStreamChoice `json:"choices"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertTools(tools []api.ToolSpec) []Tool {
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
openaiTools := make([]Tool, len(tools))
|
openaiTools := make([]Tool, len(tools))
|
||||||
for i, tool := range tools {
|
for i, tool := range tools {
|
||||||
|
71
pkg/api/provider/openai/types.go
Normal file
71
pkg/api/provider/openai/types.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
type OpenAIClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDefinition struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
Arguments string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Message ChatCompletionMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamChoice struct {
|
||||||
|
Delta ChatCompletionMessage `json:"delta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamResponse struct {
|
||||||
|
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||||
|
}
|
@ -96,7 +96,7 @@ func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProv
|
|||||||
if m == model {
|
if m == model {
|
||||||
switch *p.Kind {
|
switch *p.Kind {
|
||||||
case "anthropic":
|
case "anthropic":
|
||||||
url := "https://api.anthropic.com"
|
url := "https://api.anthropic.com/v1"
|
||||||
if p.BaseURL != nil {
|
if p.BaseURL != nil {
|
||||||
url = *p.BaseURL
|
url = *p.BaseURL
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user