Moved api.ChatCompletionProvider, api.Chunk to api/provider
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
package api
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
@@ -9,3 +12,70 @@ type Conversation struct {
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID *uint `gorm:"index"`
|
||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||
if force {
|
||||
m[0].Content = system
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
return append([]Message{{
|
||||
Role: MessageRoleSystem,
|
||||
Content: system,
|
||||
}}, m...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MessageRole) IsAssistant() bool {
|
||||
switch *m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID *uint `gorm:"index"`
|
||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||
if force {
|
||||
m[0].Content = system
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
return append([]Message{{
|
||||
Role: MessageRoleSystem,
|
||||
Content: system,
|
||||
}}, m...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MessageRole) IsAssistant() bool {
|
||||
switch *m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
const ANTHROPIC_VERSION = "2023-06-01"
|
||||
@@ -117,7 +118,7 @@ func convertTools(tools []api.ToolSpec) []Tool {
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (string, ChatCompletionRequest) {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
@@ -188,7 +189,8 @@ func createChatCompletionRequest(
|
||||
}
|
||||
|
||||
var prefill string
|
||||
if api.IsAssistantContinuation(messages) {
|
||||
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
||||
// Prompting on an assitant message, use its content as prefill
|
||||
prefill = messages[len(messages)-1].Content
|
||||
}
|
||||
|
||||
@@ -226,7 +228,7 @@ func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionReque
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
@@ -253,9 +255,9 @@ func (c *AnthropicClient) CreateChatCompletion(
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
@@ -349,7 +351,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
firstChunkReceived = true
|
||||
}
|
||||
block.Text += text
|
||||
output <- api.Chunk{
|
||||
output <- provider.Chunk{
|
||||
Content: text,
|
||||
TokenCount: 1,
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
@@ -172,7 +173,7 @@ func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionRespons
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
@@ -279,7 +280,7 @@ func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
@@ -351,9 +352,9 @@ func (c *Client) CreateChatCompletion(
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
@@ -425,7 +426,7 @@ func (c *Client) CreateChatCompletionStream(
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- api.Chunk{
|
||||
output <- provider.Chunk{
|
||||
Content: part.Text,
|
||||
TokenCount: uint(tokens),
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
@@ -42,7 +43,7 @@ type OllamaResponse struct {
|
||||
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
@@ -82,7 +83,7 @@ func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
@@ -122,9 +123,9 @@ func (c *OllamaClient) CreateChatCompletion(
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
@@ -173,7 +174,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
|
||||
}
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
output <- provider.Chunk{
|
||||
Content: streamResp.Message.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
@@ -140,7 +141,7 @@ func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) ChatCompletionRequest {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
@@ -219,7 +220,7 @@ func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest)
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
@@ -267,9 +268,9 @@ func (c *OpenAIClient) CreateChatCompletion(
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
params provider.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
output chan<- provider.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
@@ -333,7 +334,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
}
|
||||
}
|
||||
if len(delta.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
output <- provider.Chunk{
|
||||
Content: delta.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package api
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
type ReplyCallback func(Message)
|
||||
type ReplyCallback func(api.Message)
|
||||
|
||||
type Chunk struct {
|
||||
Content string
|
||||
@@ -18,7 +20,7 @@ type RequestParameters struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
Toolbox []ToolSpec
|
||||
Toolbox []api.ToolSpec
|
||||
}
|
||||
|
||||
type ChatCompletionProvider interface {
|
||||
@@ -28,22 +30,15 @@ type ChatCompletionProvider interface {
|
||||
CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []Message,
|
||||
) (*Message, error)
|
||||
messages []api.Message,
|
||||
) (*api.Message, error)
|
||||
|
||||
// Like CreateChageCompletion, except the response is streamed via
|
||||
// the output channel as it's received.
|
||||
CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []Message,
|
||||
messages []api.Message,
|
||||
chunks chan<- Chunk,
|
||||
) (*Message, error)
|
||||
}
|
||||
|
||||
func IsAssistantContinuation(messages []Message) bool {
|
||||
if len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
return messages[len(messages)-1].Role == MessageRoleAssistant
|
||||
) (*api.Message, error)
|
||||
}
|
||||
Reference in New Issue
Block a user