Refactor pkg/lmcli/provider

Moved the `ChatCompletionClient` interface to `pkg/api`, and moved the
individual providers to `pkg/api/provider`.
Matt Low 2024-06-09 16:42:53 +00:00
parent d2d946b776
commit a2c860252f
12 changed files with 37 additions and 37 deletions
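
For orientation, here is roughly what the relocated interface and types look like after the move. This is a sketch reconstructed from the call sites in the diff below; only the members visible in this commit are shown, and the exact signature of ReplyCallback is an assumption inferred from the callback usage in the Prompt function further down.

package api

import (
	"context"

	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
)

// Chunk is one piece of streamed completion output.
type Chunk struct {
	Content string
}

// ReplyCallback receives each completed reply message (assumed
// signature, based on the func(model.Message) callback in Prompt).
type ReplyCallback func(model.Message)

// ChatCompletionClient is the interface each provider implements.
type ChatCompletionClient interface {
	CreateChatCompletion(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
	) (string, error)

	CreateChatCompletionStream(
		ctx context.Context,
		params model.RequestParameters,
		messages []model.Message,
		callback ReplyCallback,
		output chan<- Chunk,
	) (string, error)
}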

View File

@@ -1,4 +1,4 @@
-package provider
+package api
 
 import (
 	"context"

View File

@@ -10,8 +10,8 @@ import (
 	"net/http"
 	"strings"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )
@@ -107,7 +107,7 @@ func (c *AnthropicClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -160,8 +160,8 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
-	output chan<- provider.Chunk,
+	callback api.ReplyCallback,
+	output chan<- api.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -242,7 +242,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 				return "", fmt.Errorf("invalid text delta")
 			}
 			sb.WriteString(text)
-			output <- provider.Chunk{
+			output <- api.Chunk{
 				Content: text,
 			}
 		case "content_block_stop":
@@ -264,7 +264,7 @@ func (c *AnthropicClient) CreateChatCompletionStream(
 			}
 			sb.WriteString(FUNCTION_STOP_SEQUENCE)
-			output <- provider.Chunk{
+			output <- api.Chunk{
 				Content: FUNCTION_STOP_SEQUENCE,
 			}
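
The streaming pattern each provider follows is unchanged by this refactor: accumulate the full reply in a strings.Builder while forwarding each text delta on the output channel. A minimal sketch of that loop, where nextDelta is a hypothetical stand-in for the provider-specific event decoding seen above:

// streamReply shows the provider-side pattern: forward each delta as
// an api.Chunk and return the accumulated text when the stream ends.
// nextDelta is a hypothetical stand-in for provider-specific decoding.
func streamReply(nextDelta func() (text string, done bool, err error), output chan<- api.Chunk) (string, error) {
	var sb strings.Builder
	for {
		text, done, err := nextDelta()
		if err != nil {
			return "", err
		}
		if done {
			return sb.String(), nil
		}
		sb.WriteString(text)
		output <- api.Chunk{Content: text} // consumer renders this immediately
	}
}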

View File

@@ -10,8 +10,8 @@ import (
 	"net/http"
 	"strings"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )
@@ -187,7 +187,7 @@ func handleToolCalls(
 	params model.RequestParameters,
 	content string,
 	toolCalls []model.ToolCall,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 	messages []model.Message,
 ) ([]model.Message, error) {
 	lastMessage := messages[len(messages)-1]
@@ -245,7 +245,7 @@ func (c *Client) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -325,8 +325,8 @@ func (c *Client) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
-	output chan<- provider.Chunk,
+	callback api.ReplyCallback,
+	output chan<- api.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -393,7 +393,7 @@ func (c *Client) CreateChatCompletionStream(
 			if part.FunctionCall != nil {
 				toolCalls = append(toolCalls, *part.FunctionCall)
 			} else if part.Text != "" {
-				output <- provider.Chunk {
+				output <- api.Chunk {
 					Content: part.Text,
 				}
 				content.WriteString(part.Text)
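
Call sites only swap the package qualifier on the callback type. As a hedged usage sketch, assuming api.ReplyCallback is func(model.Message) as the Prompt signature later in this diff suggests, and glossing over client construction:

func printReplies(ctx context.Context, client api.ChatCompletionClient,
	params model.RequestParameters, messages []model.Message) error {
	// Each reply message, including the extra rounds produced by
	// handleToolCalls, is delivered through the callback.
	_, err := client.CreateChatCompletion(ctx, params, messages,
		func(m model.Message) {
			fmt.Println(m.Content) // m.Content assumed; field names are not shown in this diff
		})
	return err
}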

View File

@@ -10,8 +10,8 @@ import (
 	"net/http"
 	"strings"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 )
 
 type OllamaClient struct {
@@ -85,7 +85,7 @@ func (c *OllamaClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -131,8 +131,8 @@ func (c *OllamaClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
-	output chan<- provider.Chunk,
+	callback api.ReplyCallback,
+	output chan<- api.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -181,7 +181,7 @@ func (c *OllamaClient) CreateChatCompletionStream(
 		}
 
 		if len(streamResp.Message.Content) > 0 {
-			output <- provider.Chunk{
+			output <- api.Chunk{
 				Content: streamResp.Message.Content,
 			}
 			content.WriteString(streamResp.Message.Content)

View File

@@ -10,8 +10,8 @@ import (
 	"net/http"
 	"strings"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 )
@@ -121,7 +121,7 @@ func handleToolCalls(
 	params model.RequestParameters,
 	content string,
 	toolCalls []ToolCall,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 	messages []model.Message,
 ) ([]model.Message, error) {
 	lastMessage := messages[len(messages)-1]
@@ -180,7 +180,7 @@ func (c *OpenAIClient) CreateChatCompletion(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
+	callback api.ReplyCallback,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -244,8 +244,8 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 	ctx context.Context,
 	params model.RequestParameters,
 	messages []model.Message,
-	callback provider.ReplyCallback,
-	output chan<- provider.Chunk,
+	callback api.ReplyCallback,
+	output chan<- api.Chunk,
 ) (string, error) {
 	if len(messages) == 0 {
 		return "", fmt.Errorf("Can't create completion from no messages")
@@ -319,7 +319,7 @@ func (c *OpenAIClient) CreateChatCompletionStream(
 			}
 		}
 		if len(delta.Content) > 0 {
-			output <- provider.Chunk {
+			output <- api.Chunk {
 				Content: delta.Content,
 			}
 			content.WriteString(delta.Content)

View File

@@ -8,9 +8,9 @@ import (
 	"strings"
 	"time"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"github.com/charmbracelet/lipgloss"
 )
@@ -18,7 +18,7 @@ import (
 // Prompt prompts the configured the configured model and streams the response
 // to stdout. Returns all model reply messages.
 func Prompt(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
-	content := make(chan provider.Chunk) // receives the reponse from LLM
+	content := make(chan api.Chunk) // receives the reponse from LLM
 	defer close(content)
 
 	// render all content received over the channel
@@ -252,7 +252,7 @@ func ShowWaitAnimation(signal chan any) {
 // chunked) content is received on the channel, the waiting animation is
 // replaced by the content.
 // Blocks until the channel is closed.
-func ShowDelayedContent(content <-chan provider.Chunk) {
+func ShowDelayedContent(content <-chan api.Chunk) {
 	waitSignal := make(chan any)
 	go ShowWaitAnimation(waitSignal)
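
On the consuming side the channel is just a stream of text fragments, drained until the producer closes it, as ShowDelayedContent does. A minimal consumer in the same spirit (a sketch, assuming only the api.Chunk type shown in this diff):

// printChunks drains a chunk channel, printing content as it arrives.
// It returns once the producer closes the channel.
func printChunks(content <-chan api.Chunk) {
	for chunk := range content {
		fmt.Print(chunk.Content)
	}
}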

View File

@@ -6,12 +6,12 @@ import (
 	"path/filepath"
 	"strings"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
+	"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
+	"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
+	"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
+	"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/google"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/ollama"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
 	"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
 	"git.mlow.ca/mlow/lmcli/pkg/util"
 	"git.mlow.ca/mlow/lmcli/pkg/util/tty"
@@ -79,7 +79,7 @@ func (c *Context) GetModels() (models []string) {
 	return
 }
 
-func (c *Context) GetModelProvider(model string) (string, provider.ChatCompletionClient, error) {
+func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionClient, error) {
 	parts := strings.Split(model, "/")
 
 	var provider string
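
GetModelProvider resolves a model string that may carry a provider prefix, as the strings.Split on "/" above suggests. A sketch of that parsing step (splitModelSpec is a hypothetical helper, and the fallback when no prefix is present is an assumption; the real logic lives inline in GetModelProvider):

// splitModelSpec separates an optional provider prefix from a model
// name, e.g. "openai/gpt-4" -> ("openai", "gpt-4").
func splitModelSpec(spec string) (provider, model string) {
	parts := strings.SplitN(spec, "/", 2)
	if len(parts) == 2 {
		return parts[0], parts[1]
	}
	return "", parts[0] // no prefix: provider presumably chosen from config defaults
}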

View File

@@ -3,8 +3,8 @@ package chat
 import (
 	"time"
 
+	"git.mlow.ca/mlow/lmcli/pkg/api"
 	models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
-	"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
 	"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
 	"github.com/charmbracelet/bubbles/cursor"
 	"github.com/charmbracelet/bubbles/spinner"
@@ -17,7 +17,7 @@ import (
 // custom tea.Msg types
 type (
 	// sent on each chunk received from LLM
-	msgResponseChunk provider.Chunk
+	msgResponseChunk api.Chunk
 	// sent when response is finished being received
 	msgResponseEnd string
 	// a special case of common.MsgError that stops the response waiting animation
@@ -83,7 +83,7 @@ type Model struct {
 	editorTarget   editorTarget
 	stopSignal     chan struct{}
 	replyChan      chan models.Message
-	replyChunkChan chan provider.Chunk
+	replyChunkChan chan api.Chunk
 	persistence    bool // whether we will save new messages in the conversation
 
 	// ui state
@@ -115,7 +115,7 @@ func Chat(shared shared.Shared) Model {
 		stopSignal:      make(chan struct{}),
 		replyChan:       make(chan models.Message),
-		replyChunkChan:  make(chan provider.Chunk),
+		replyChunkChan:  make(chan api.Chunk),
 		wrap:            true,
 		selectedMessage: -1,
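
In the TUI, replyChunkChan bridges the provider stream into Bubble Tea's message loop: a command blocks on the channel and wraps each received value as a msgResponseChunk. A sketch of that bridge (the model's actual command wiring may differ; since msgResponseChunk is a defined type over api.Chunk, the conversion below is valid Go):

// waitForChunk converts the next value received on the chunk channel
// into a tea.Msg; Update re-issues the command after each message.
func waitForChunk(ch <-chan api.Chunk) tea.Cmd {
	return func() tea.Msg {
		return msgResponseChunk(<-ch)
	}
}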