190 lines
4.1 KiB
Go
190 lines
4.1 KiB
Go
package ollama
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
|
"git.mlow.ca/mlow/lmcli/pkg/api/provider"
|
|
)
|
|
|
|
type (
	// OllamaClient sends chat requests to an Ollama server. BaseURL is used
	// as the prefix for endpoint paths such as "/chat".
	OllamaClient struct {
		BaseURL string
	}

	// OllamaMessage is a single chat message as exchanged with Ollama.
	OllamaMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// OllamaRequest is the JSON body posted to the chat endpoint. Stream
	// selects between a single response object and a stream of them.
	OllamaRequest struct {
		Model    string          `json:"model"`
		Messages []OllamaMessage `json:"messages"`
		Stream   bool            `json:"stream"`
	}

	// OllamaResponse is one response object from the chat endpoint. It is
	// decoded both from a complete non-streaming response and from each line
	// of a streaming response. The duration fields are reported by the
	// server (presumably in nanoseconds per the Ollama API docs — confirm);
	// the count fields are token counts.
	OllamaResponse struct {
		Model              string        `json:"model"`
		CreatedAt          string        `json:"created_at"`
		Message            OllamaMessage `json:"message"`
		Done               bool          `json:"done"`
		TotalDuration      uint64        `json:"total_duration,omitempty"`
		LoadDuration       uint64        `json:"load_duration,omitempty"`
		PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
		PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
		EvalCount          uint64        `json:"eval_count,omitempty"`
		EvalDuration       uint64        `json:"eval_duration,omitempty"`
	}
)
|
|
|
|
func createOllamaRequest(
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) OllamaRequest {
|
|
requestMessages := make([]OllamaMessage, 0, len(messages))
|
|
|
|
for _, m := range messages {
|
|
message := OllamaMessage{
|
|
Role: string(m.Role),
|
|
Content: m.Content,
|
|
}
|
|
requestMessages = append(requestMessages, message)
|
|
}
|
|
|
|
request := OllamaRequest{
|
|
Model: params.Model,
|
|
Messages: requestMessages,
|
|
}
|
|
|
|
return request
|
|
}
|
|
|
|
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
bytes, _ := io.ReadAll(resp.Body)
|
|
return resp, fmt.Errorf("%v", string(bytes))
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *OllamaClient) CreateChatCompletion(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
|
}
|
|
|
|
req := createOllamaRequest(params, messages)
|
|
req.Stream = false
|
|
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := c.sendRequest(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
var completionResp OllamaResponse
|
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &api.Message{
|
|
Role: api.MessageRoleAssistant,
|
|
Content: completionResp.Message.Content,
|
|
}, nil
|
|
}
|
|
|
|
func (c *OllamaClient) CreateChatCompletionStream(
|
|
ctx context.Context,
|
|
params provider.RequestParameters,
|
|
messages []api.Message,
|
|
output chan<- provider.Chunk,
|
|
) (*api.Message, error) {
|
|
if len(messages) == 0 {
|
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
|
}
|
|
|
|
req := createOllamaRequest(params, messages)
|
|
req.Stream = true
|
|
|
|
jsonData, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := c.sendRequest(httpReq)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
content := strings.Builder{}
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
for {
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
line = bytes.TrimSpace(line)
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
var streamResp OllamaResponse
|
|
err = json.Unmarshal(line, &streamResp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(streamResp.Message.Content) > 0 {
|
|
output <- provider.Chunk{
|
|
Content: streamResp.Message.Content,
|
|
TokenCount: 1,
|
|
}
|
|
content.WriteString(streamResp.Message.Content)
|
|
}
|
|
}
|
|
|
|
return &api.Message{
|
|
Role: api.MessageRoleAssistant,
|
|
Content: content.String(),
|
|
}, nil
|
|
}
|