Matt Low
0384c7cb66
This refactor splits out all conversation concerns into a new `conversation` package. There is now a split between `conversation`'s and `api`'s representations of `Message`, with the latter storing only the minimum information required for interaction with LLM providers. The necessary conversion between the two happens when making LLM calls.
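The `conversation` side of that split is not shown on this page, so as a rough sketch only: a conversation-level message carrying persistence concerns might convert down to the provider-facing `api.Message` along these lines. The `conversation.Message` shape, its extra fields, and `ToAPIMessage` are illustrative assumptions, not the actual package API; only `api.Message`'s `Role` and `Content` fields are grounded in the file below.

package conversation // hypothetical placement; sketch only

import (
	"time"

	"git.mlow.ca/mlow/lmcli/pkg/api"
)

// Message sketches a conversation-side message that carries persistence
// fields providers never see. All fields except Role and Content are
// assumptions for illustration.
type Message struct {
	ID        uint
	CreatedAt time.Time
	Role      string
	Content   string
}

// ToAPIMessage keeps only what LLM providers need. The api.MessageRole
// type name is assumed; the file below only shows that Role converts
// to string.
func (m Message) ToAPIMessage() api.Message {
	return api.Message{
		Role:    api.MessageRole(m.Role),
		Content: m.Content,
	}
}

Dropping ID and CreatedAt in the conversion reflects the commit's point: `api.Message` keeps only what providers need.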
184 lines
4.0 KiB
Go
package ollama

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	"git.mlow.ca/mlow/lmcli/pkg/api"
	"git.mlow.ca/mlow/lmcli/pkg/provider"
)

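// OllamaClient is a client for an Ollama server's chat API, rooted at BaseURL.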
type OllamaClient struct {
	BaseURL string
}

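// OllamaMessage is a single chat message in Ollama's wire format.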
type OllamaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

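// OllamaRequest is the request body for Ollama's /chat endpoint.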
type OllamaRequest struct {
	Model    string          `json:"model"`
	Messages []OllamaMessage `json:"messages"`
	Stream   bool            `json:"stream"`
}

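// OllamaResponse is one response object from Ollama's /chat endpoint; when
// streaming, the final object carries the timing and token-count fields.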
type OllamaResponse struct {
	Model              string        `json:"model"`
	CreatedAt          string        `json:"created_at"`
	Message            OllamaMessage `json:"message"`
	Done               bool          `json:"done"`
	TotalDuration      uint64        `json:"total_duration,omitempty"`
	LoadDuration       uint64        `json:"load_duration,omitempty"`
	PromptEvalCount    uint64        `json:"prompt_eval_count,omitempty"`
	PromptEvalDuration uint64        `json:"prompt_eval_duration,omitempty"`
	EvalCount          uint64        `json:"eval_count,omitempty"`
	EvalDuration       uint64        `json:"eval_duration,omitempty"`
}

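// createOllamaRequest converts provider-agnostic api.Messages into an
// OllamaRequest targeting the model named in params.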
func createOllamaRequest(
	params provider.RequestParameters,
	messages []api.Message,
) OllamaRequest {
	requestMessages := make([]OllamaMessage, 0, len(messages))

	for _, m := range messages {
		message := OllamaMessage{
			Role:    string(m.Role),
			Content: m.Content,
		}
		requestMessages = append(requestMessages, message)
	}

	request := OllamaRequest{
		Model:    params.Model,
		Messages: requestMessages,
	}

	return request
}

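// sendRequest executes req with a JSON Content-Type, surfacing non-200
// responses as errors that carry the response body.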
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode != 200 {
		bytes, _ := io.ReadAll(resp.Body)
		return resp, fmt.Errorf("%v", string(bytes))
	}

	return resp, nil
}

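// CreateChatCompletion performs a single non-streamed chat request and
// returns the reply as an assistant api.Message.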
func (c *OllamaClient) CreateChatCompletion(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = false

	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}

	resp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var completionResp OllamaResponse
	err = json.NewDecoder(resp.Body).Decode(&completionResp)
	if err != nil {
		return nil, err
	}

	return api.NewMessageWithAssistant(completionResp.Message.Content), nil
}

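// CreateChatCompletionStream performs a streamed chat request, sending each
// content delta to output, and returns the accumulated reply as an
// assistant api.Message.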
func (c *OllamaClient) CreateChatCompletionStream(
	ctx context.Context,
	params provider.RequestParameters,
	messages []api.Message,
	output chan<- provider.Chunk,
) (*api.Message, error) {
	if len(messages) == 0 {
		return nil, fmt.Errorf("can't create completion from no messages")
	}

	req := createOllamaRequest(params, messages)
	req.Stream = true

	jsonData, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, err
	}

	resp, err := c.sendRequest(httpReq)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	content := strings.Builder{}

	// Ollama streams newline-delimited JSON: one OllamaResponse per line.
	reader := bufio.NewReader(resp.Body)
	for {
		line, err := reader.ReadBytes('\n')
		if err != nil {
			if err == io.EOF {
				break
			}
			return nil, err
		}

		line = bytes.TrimSpace(line)
		if len(line) == 0 {
			continue
		}

		var streamResp OllamaResponse
		err = json.Unmarshal(line, &streamResp)
		if err != nil {
			return nil, err
		}

		// Forward each delta to the caller and accumulate the full reply.
		if len(streamResp.Message.Content) > 0 {
			output <- provider.Chunk{
				Content:    streamResp.Message.Content,
				TokenCount: 1,
			}
			content.WriteString(streamResp.Message.Content)
		}
	}

	return api.NewMessageWithAssistant(content.String()), nil
}