Large refactor - it compiles!
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between `conversation` and `api`s representation of `Message`, the latter storing the minimum information required for interaction with LLM providers. There is necessary conversation between the two when making LLM calls.
This commit is contained in:
118
pkg/api/api.go
Normal file
118
pkg/api/api.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Content string // TODO: support multi-part messages
|
||||
Role MessageRole
|
||||
ToolCalls []ToolCall
|
||||
ToolResults []ToolResult
|
||||
}
|
||||
|
||||
type ToolSpec struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id" yaml:"-"`
|
||||
Name string `json:"name" yaml:"tool"`
|
||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||
Result string `json:"result,omitempty" yaml:"result"`
|
||||
}
|
||||
|
||||
func NewMessageWithAssistant(content string) *Message {
|
||||
return &Message{
|
||||
Role: MessageRoleAssistant,
|
||||
Content: content,
|
||||
}
|
||||
}
|
||||
|
||||
func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message {
|
||||
return &Message{
|
||||
Role: MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
}
|
||||
|
||||
func (m MessageRole) IsAssistant() bool {
|
||||
switch m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m MessageRole) IsSystem() bool {
|
||||
switch m {
|
||||
case MessageRoleSystem:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: remove this
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type MessageMeta struct {
|
||||
GenerationProvider *string `json:"generation_provider,omitempty"`
|
||||
GenerationModel *string `json:"generation_model,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
CreatedAt time.Time
|
||||
Metadata MessageMeta
|
||||
|
||||
ConversationID *uint `gorm:"index"`
|
||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
|
||||
Role MessageRole
|
||||
Content string
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
}
|
||||
|
||||
func (m *MessageMeta) Scan(value interface{}) error {
|
||||
return json.Unmarshal(value.([]byte), m)
|
||||
}
|
||||
|
||||
func (m MessageMeta) Value() (driver.Value, error) {
|
||||
return json.Marshal(m)
|
||||
}
|
||||
|
||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||
if force {
|
||||
m[0].Content = system
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
return append([]Message{{
|
||||
Role: MessageRoleSystem,
|
||||
Content: system,
|
||||
}}, m...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m MessageRole) IsAssistant() bool {
|
||||
switch m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m MessageRole) IsSystem() bool {
|
||||
switch m {
|
||||
case MessageRoleSystem:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
||||
@@ -1,453 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
const ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
type AnthropicClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema InputSchema `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
partialJsonAccumulator string
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Delta interface{} `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
anthropicTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
properties := make(map[string]Property)
|
||||
for _, param := range tool.Parameters {
|
||||
properties[param.Name] = Property{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
}
|
||||
|
||||
var required []string
|
||||
for _, param := range tool.Parameters {
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
anthropicTools[i] = Tool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return anthropicTools
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (string, ChatCompletionRequest) {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
var systemMessage string
|
||||
|
||||
for _, m := range messages {
|
||||
if m.Role == api.MessageRoleSystem {
|
||||
systemMessage = m.Content
|
||||
continue
|
||||
}
|
||||
|
||||
var content interface{}
|
||||
role := string(m.Role)
|
||||
|
||||
switch m.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
role = "assistant"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
if m.Content != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, toolCall := range m.ToolCalls {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCall.ID,
|
||||
"name": toolCall.Name,
|
||||
"input": toolCall.Parameters,
|
||||
})
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
case api.MessageRoleToolResult:
|
||||
role = "user"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
for _, result := range m.ToolResults {
|
||||
contentBlock := map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": result.ToolCallID,
|
||||
"content": result.Result,
|
||||
}
|
||||
contentBlocks = append(contentBlocks, contentBlock)
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
default:
|
||||
content = m.Content
|
||||
}
|
||||
|
||||
requestMessages = append(requestMessages, ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
request := ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
System: systemMessage,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
}
|
||||
|
||||
var prefill string
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
||||
// Prompting on an assitant message, use its content as prefill
|
||||
prefill = messages[len(messages)-1].Content
|
||||
}
|
||||
|
||||
return prefill, request
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||
jsonData, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
_, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return convertResponseToMessage(completionResp)
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
prefill, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentBlocks := make(map[int]*ContentBlock)
|
||||
var finalMessage *ChatCompletionResponse
|
||||
|
||||
var firstChunkReceived bool
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamEvent StreamEvent
|
||||
err = json.Unmarshal(line, &streamEvent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
||||
}
|
||||
|
||||
switch streamEvent.Type {
|
||||
case "message_start":
|
||||
finalMessage = &ChatCompletionResponse{}
|
||||
err = json.Unmarshal(line, &struct {
|
||||
Message *ChatCompletionResponse `json:"message"`
|
||||
}{Message: finalMessage})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
||||
}
|
||||
case "content_block_start":
|
||||
var contentBlockStart struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
err = json.Unmarshal(line, &contentBlockStart)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
||||
}
|
||||
|
||||
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
||||
case "content_block_delta":
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
||||
}
|
||||
|
||||
deltaType, ok := delta["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("delta missing type field")
|
||||
}
|
||||
|
||||
switch deltaType {
|
||||
case "text_delta":
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
if !firstChunkReceived {
|
||||
if prefill == "" {
|
||||
// if there is no prefil, ensure we trim leading whitespace
|
||||
text = strings.TrimSpace(text)
|
||||
}
|
||||
firstChunkReceived = true
|
||||
}
|
||||
block.Text += text
|
||||
output <- provider.Chunk{
|
||||
Content: text,
|
||||
// rough, anthropic performs some chunking
|
||||
TokenCount: uint(len(strings.Split(text, " "))),
|
||||
}
|
||||
}
|
||||
case "input_json_delta":
|
||||
if block.Type != "tool_use" {
|
||||
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
block.partialJsonAccumulator += partialJSON
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
||||
var inputData map[string]interface{}
|
||||
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
||||
}
|
||||
block.Input = inputData
|
||||
}
|
||||
case "message_delta":
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("received message_delta before message_start")
|
||||
}
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
||||
}
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok {
|
||||
finalMessage.StopReason = stopReason
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
// End of the stream
|
||||
goto END_STREAM
|
||||
|
||||
case "error":
|
||||
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
||||
|
||||
default:
|
||||
// Ignore unknown event types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
END_STREAM:
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("no final message received")
|
||||
}
|
||||
|
||||
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
||||
for _, v := range contentBlocks {
|
||||
finalMessage.Content = append(finalMessage.Content, *v)
|
||||
}
|
||||
|
||||
return convertResponseToMessage(*finalMessage)
|
||||
}
|
||||
|
||||
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
||||
content := strings.Builder{}
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
content.WriteString(block.Text)
|
||||
case "tool_use":
|
||||
parameters, ok := block.Input.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Name: block.Name,
|
||||
Parameters: parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
message := &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
message.Role = api.MessageRoleToolCall
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
@@ -1,451 +0,0 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]string `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []ContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GenerationConfig struct {
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
||||
type GenerateContentResponse struct {
|
||||
Candidates []Candidate `json:"candidates"`
|
||||
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type FunctionDeclaration struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Values []string `json:"values,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
geminiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
// TODO: proper enum handing
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Values: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
geminiTools[i] = Tool{
|
||||
FunctionDeclarations: []FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "OBJECT",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
for k, v := range call.Parameters {
|
||||
args[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
converted[i].FunctionCall = &FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(functionCalls))
|
||||
for i, call := range functionCalls {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range call.Args {
|
||||
params[k] = v
|
||||
}
|
||||
converted[i].Name = call.Name
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
||||
results := make([]FunctionResponse, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||
}
|
||||
results[i] = FunctionResponse{
|
||||
Name: result.ToolName,
|
||||
Response: obj,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
startIdx := 0
|
||||
var system string
|
||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||
system = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
for _, m := range messages[startIdx:] {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
content := Content{
|
||||
Role: "model",
|
||||
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
case "tool_result":
|
||||
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// expand tool_result messages' results into multiple gemini messages
|
||||
for _, result := range results {
|
||||
content := Content{
|
||||
Role: "function",
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
FunctionResp: &result,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case api.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case api.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
panic("Unhandled role: " + m.Role)
|
||||
}
|
||||
|
||||
content := Content{
|
||||
Role: role,
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
}
|
||||
|
||||
request := &GenerateContentRequest{
|
||||
Contents: requestContents,
|
||||
GenerationConfig: &GenerationConfig{
|
||||
MaxOutputTokens: ¶ms.MaxTokens,
|
||||
Temperature: ¶ms.Temperature,
|
||||
TopP: ¶ms.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if system != "" {
|
||||
request.SystemInstruction = &Content{
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: system,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp GenerateContentResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Candidates[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
lastTokenCount := 0
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var resp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||
lastTokenCount += tokens
|
||||
|
||||
choice := resp.Candidates[0]
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- provider.Chunk{
|
||||
Content: part.Text,
|
||||
TokenCount: uint(tokens),
|
||||
}
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are function calls, handle them and recurse
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
@@ -1,189 +0,0 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type OllamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OllamaResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message OllamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration uint64 `json:"total_duration,omitempty"`
|
||||
LoadDuration uint64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount uint64 `json:"eval_count,omitempty"`
|
||||
EvalDuration uint64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
message := OllamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
|
||||
request := OllamaRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp OllamaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: completionResp.Message.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var streamResp OllamaResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
output <- provider.Chunk{
|
||||
Content: streamResp.Message.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(streamResp.Message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
@@ -1,357 +0,0 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type FunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
N int `json:"n"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionChoice struct {
|
||||
Message ChatCompletionMessage `json:"message"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
Choices []ChatCompletionChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamChoice struct {
|
||||
Delta ChatCompletionMessage `json:"delta"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamResponse struct {
|
||||
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
openaiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
openaiTools[i].Function = FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "object",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
||||
converted := make([]ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Function.Name = call.Name
|
||||
|
||||
json, _ := json.Marshal(call.Parameters)
|
||||
converted[i].Function.Arguments = string(json)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) ChatCompletionRequest {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = "assistant"
|
||||
message.Content = m.Content
|
||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
requestMessages = append(requestMessages, message)
|
||||
case "tool_result":
|
||||
// expand tool_result messages' results into multiple openAI messages
|
||||
for _, result := range m.ToolResults {
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = "tool"
|
||||
message.Content = result.Result
|
||||
message.ToolCallID = result.ToolCallID
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
default:
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = string(m.Role)
|
||||
message.Content = m.Content
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
request := ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Messages: requestMessages,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||
jsonData, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
for header, val := range c.Headers {
|
||||
req.Header.Set(header, val)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Choices[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content + choice.Message.Content
|
||||
} else {
|
||||
content = choice.Message.Content
|
||||
}
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []ToolCall{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(line, []byte("[DONE]")) {
|
||||
break
|
||||
}
|
||||
|
||||
var streamResp ChatCompletionStreamResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delta := streamResp.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta.Content) > 0 {
|
||||
output <- provider.Chunk{
|
||||
Content: delta.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
type Chunk struct {
|
||||
Content string
|
||||
TokenCount uint
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
Model string
|
||||
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
Toolbox []api.ToolSpec
|
||||
}
|
||||
|
||||
type ChatCompletionProvider interface {
|
||||
// CreateChatCompletion generates a chat completion response to the
|
||||
// provided messages.
|
||||
CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error)
|
||||
|
||||
// Like CreateChageCompletion, except the response is streamed via
|
||||
// the output channel.
|
||||
CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []api.Message,
|
||||
chunks chan<- Chunk,
|
||||
) (*api.Message, error)
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ToolSpec struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id" yaml:"-"`
|
||||
Name string `json:"name" yaml:"tool"`
|
||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||
Result string `json:"result,omitempty" yaml:"result"`
|
||||
}
|
||||
|
||||
type ToolCalls []ToolCall
|
||||
|
||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tc)
|
||||
return
|
||||
}
|
||||
|
||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||
if len(tc) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(tc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (tr *ToolResults) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (tr ToolResults) Value() (driver.Value, error) {
|
||||
if len(tr) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
Reference in New Issue
Block a user