Refactor pkg/lmcli/provider
Moved `ChangeCompletionInterface` to `pkg/api`, moved individual providers to `pkg/api/provider`
This commit is contained in:
199
pkg/api/provider/ollama/ollama.go
Normal file
199
pkg/api/provider/ollama/ollama.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type OllamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OllamaResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message OllamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration uint64 `json:"total_duration,omitempty"`
|
||||
LoadDuration uint64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount uint64 `json:"eval_count,omitempty"`
|
||||
EvalDuration uint64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
message := OllamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
|
||||
request := OllamaRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OllamaClient) sendRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp OllamaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
content := completionResp.Message.Content
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback api.ReplyCallback,
|
||||
output chan<- api.Chunk,
|
||||
) (string, error) {
|
||||
if len(messages) == 0 {
|
||||
return "", fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequest("POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(ctx, httpReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var streamResp OllamaResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
Content: streamResp.Message.Content,
|
||||
}
|
||||
content.WriteString(streamResp.Message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
return content.String(), nil
|
||||
}
|
||||
Reference in New Issue
Block a user