Compare commits
125 Commits
Author | SHA1 | Date | |
---|---|---|---|
f05e2e30f7 | |||
ec21a02ec0 | |||
07c96082e7 | |||
0384c7cb66 | |||
2ea8a73eb5 | |||
c9a7eee090 | |||
ae1e85e166 | |||
304820c919 | |||
93c2fb3d1e | |||
bb48bc9abd | |||
5d13c3e056 | |||
327a128b2f | |||
a441866f2f | |||
ce7b07ad95 | |||
2fed682969 | |||
69cdc0a5aa | |||
3ec2675632 | |||
172bfc57e1 | |||
a46d211e10 | |||
676aa7b004 | |||
b7c89a4dd1 | |||
b8e3172ce0 | |||
a1fdf3f7cd | |||
a488ec4fd8 | |||
463ca9ef40 | |||
24b5cdbbf6 | |||
7c0bfefc65 | |||
443c8096d3 | |||
1570988b98 | |||
434fc4672b | |||
fe838f400f | |||
e59ce973b6 | |||
4ef841e945 | |||
c68084f8a5 | |||
8ca044b6af | |||
6f5cf68208 | |||
914d9ac0c1 | |||
8ddac2f820 | |||
cea5118cac | |||
a43a91c6ff | |||
ba7018af11 | |||
f89cc7b410 | |||
677cfcfebf | |||
11402c5534 | |||
a1fc8a637b | |||
94d84ba7d7 | |||
c50b6b154d | |||
31df055430 | |||
c30e652103 | |||
3fde58b77d | |||
85a2abbbf3 | |||
dfe43179c0 | |||
42c3297e54 | |||
a22119f738 | |||
a2c860252f | |||
d2d946b776 | |||
c963747066 | |||
e334d9fc4f | |||
c1ead83939 | |||
c9e92e186e | |||
45df957a06 | |||
136c463924 | |||
2580087b4d | |||
60a474d516 | |||
ea576d24a6 | |||
465b1d333e | |||
b29a4c8b84 | |||
58e1b84fea | |||
a6522dbcd0 | |||
97cd047861 | |||
ed784bb1cf | |||
c1792f27ff | |||
0ad698a942 | |||
0d66a49997 | |||
008fdc0d37 | |||
eec9eb41e9 | |||
437997872a | |||
3536438dd1 | |||
f5ce970102 | |||
5c1248184b | |||
8c53752146 | |||
f6e55f6bff | |||
dc1edf8c3e | |||
62d98289e8 | |||
b82f3019f0 | |||
1bd953676d | |||
a291e7b42c | |||
1b8d04c96d | |||
cbcd3b1ba9 | |||
75bf9f6125 | |||
9ff4322995 | |||
54f5a3c209 | |||
86bdc733bf | |||
60394de620 | |||
aeeb7bb7f7 | |||
2b38db7db7 | |||
8e4ff90ab4 | |||
bdaf6204f6 | |||
1b9a8f319c | |||
ffe9d299ef | |||
08a2027332 | |||
b06e031ee0 | |||
69d3265b64 | |||
7463b7502c | |||
0e68e22efa | |||
1404cae6a7 | |||
9e6d41a3ff | |||
39cd4227c6 | |||
105ee2e01b | |||
e1970a315a | |||
020db40401 | |||
811ec4b251 | |||
c68cb14eb9 | |||
cef87a55d8 | |||
29519fa2f3 | |||
2e3779ad32 | |||
9cd28d28d7 | |||
0b991800d6 | |||
5af857edae | |||
3e24a54d0a | |||
a669313a0b | |||
6310021dca | |||
ef929da68c | |||
c51644e78e | |||
91c74d9e1e |
199
README.md
199
README.md
@ -1,59 +1,174 @@
|
|||||||
# lmcli
|
# lmcli - Large ____ Model CLI
|
||||||
|
|
||||||
`lmcli` is a (Large) Language Model CLI.
|
`lmcli` is a versatile command-line interface for interacting with LLMs and LMMs.
|
||||||
|
|
||||||
Current features:
|
## Features
|
||||||
- Perform one-shot prompts with `lmcli prompt <message>`
|
|
||||||
- Manage persistent conversations with the `new`, `reply`, `view`, `rm`,
|
|
||||||
`edit`, `retry`, `continue` sub-commands.
|
|
||||||
- Syntax highlighted output
|
|
||||||
- Tool calling, see the [Tools](#tools) section.
|
|
||||||
|
|
||||||
Maybe features:
|
- Multiple model backends (Ollama, OpenAI, Anthropic, Google)
|
||||||
- Chat-like interface (`lmcli chat`) for rapid back-and-forth conversations
|
- Customizable agents with tool calling
|
||||||
- Support for additional models/APIs besides just OpenAI
|
- Persistent conversation management
|
||||||
|
- Message branching (edit and re-prompt to your heart's desire)
|
||||||
|
- Interactive terminal interface for seamless chat experiences
|
||||||
|
- Utilizes `$EDITOR`, write and edit prompts from the comfort of your own editor :)
|
||||||
|
- `vi`-like bindings
|
||||||
|
- Syntax highlighting!
|
||||||
|
|
||||||
|
## Screenshots
|
||||||
|
|
||||||
|
[TODO: Add screenshots of the TUI in action, showing different views and features]
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
To install `lmcli`, make sure you have Go installed on your system, then run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go install git.mlow.ca/mlow/lmcli@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
`lmcli` uses a YAML configuration file located at `~/.config/lmcli/config.yaml`. Here's a sample configuration:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
defaults:
|
||||||
|
model: claude-3-5-sonnet-20240620
|
||||||
|
maxTokens: 3072
|
||||||
|
temperature: 0.2
|
||||||
|
conversations:
|
||||||
|
titleGenerationModel: claude-3-haiku-20240307
|
||||||
|
chroma:
|
||||||
|
style: onedark
|
||||||
|
formatter: terminal16m
|
||||||
|
agents:
|
||||||
|
- name: coder
|
||||||
|
tools:
|
||||||
|
- dir_tree
|
||||||
|
- read_file
|
||||||
|
#- write_file
|
||||||
|
systemPrompt: |
|
||||||
|
You are an experienced software engineer...
|
||||||
|
# ...
|
||||||
|
providers:
|
||||||
|
- kind: ollama
|
||||||
|
models:
|
||||||
|
- phi3:instruct
|
||||||
|
- llama3:8b
|
||||||
|
- kind: anthropic
|
||||||
|
apiKey: your-api-key-here
|
||||||
|
models:
|
||||||
|
- claude-3-5-sonnet-20240620
|
||||||
|
- claude-3-opus-20240229
|
||||||
|
- claude-3-haiku-20240307
|
||||||
|
- kind: openai
|
||||||
|
apiKey: your-api-key-here
|
||||||
|
models:
|
||||||
|
- gpt-4o
|
||||||
|
- gpt-4-turbo
|
||||||
|
- name: openrouter
|
||||||
|
kind: openai
|
||||||
|
apiKey: your-api-key-here
|
||||||
|
baseUrl: https://openrouter.ai/api/
|
||||||
|
models:
|
||||||
|
- qwen/qwen-2-72b-instruct
|
||||||
|
# ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Customize this file to add your own providers, agents, and models.
|
||||||
|
|
||||||
|
### Syntax highlighting
|
||||||
|
|
||||||
|
Syntax highlighting is performed by [Chroma](https://github.com/alecthomas/chroma).
|
||||||
|
|
||||||
|
Refer to [`Chroma/styles`](https://github.com/alecthomas/chroma/tree/master/styles) for available styles (TODO: add support for custom Chroma styles).
|
||||||
|
|
||||||
|
Available formatters:
|
||||||
|
- `terminal` - 8 colors
|
||||||
|
- `terminal16` - 16 colors
|
||||||
|
- `terminal256` - 256 colors
|
||||||
|
- `terminal16m` - true color (default)
|
||||||
|
|
||||||
|
## Agents
|
||||||
|
|
||||||
|
Agents in `lmcli` combine a system prompt with a set of available tools. Agents are defined in `config.yaml` and are called upon with the `-a`/`--agent` flag.
|
||||||
|
|
||||||
|
Agent functionality is expected to be expanded on, bringing them to close parity with something like OpenAI's "Assistants" feature.
|
||||||
|
|
||||||
## Tools
|
## Tools
|
||||||
Tools must be explicitly enabled by adding the tool's name to the
|
|
||||||
`openai.enabledTools` array in `config.yaml`.
|
|
||||||
|
|
||||||
Note: all filesystem related tools operate relative to the current directory
|
Tools are used by agents to acquire information from and interact with external systems. The following built-in tools are available:
|
||||||
only. They do not accept absolute paths, and efforts are made to ensure they
|
|
||||||
cannot escape above the working directory). **Close attention must be paid to
|
|
||||||
where you are running `lmcli`, as the model could at any time decide to use one
|
|
||||||
of these tools to discover and read potentially sensitive information from your
|
|
||||||
filesystem.**
|
|
||||||
|
|
||||||
It's best to only have tools enabled in `config.yaml` when you intend to be
|
- `dir_tree`: Display a directory structure
|
||||||
using them, since their descriptions (see `pkg/cli/functions.go`) count towards
|
- `read_file`: Read the contents of a file
|
||||||
context usage.
|
- `write_file`: Write content to a file
|
||||||
|
- `file_insert_lines`: Insert lines at a specific position in a file
|
||||||
|
- `file_replace_lines`: Replace a range of lines in a file
|
||||||
|
|
||||||
Available tools:
|
Obviously, some of these tools carry significant risk. Use wisely :)
|
||||||
|
|
||||||
- `read_dir` - Read the contents of a directory.
|
More tool features are planned, including the ability to define arbitrary tools which call out to external scripts, tools to spawn sub-agents, perform web searches, etc.
|
||||||
- `read_file` - Read the contents of a file.
|
|
||||||
- `write_file` - Write contents to a file.
|
|
||||||
- `file_insert_lines` - Insert lines at a position within a file. Tricky for
|
|
||||||
the model to use, but can potentially save tokens.
|
|
||||||
- `file_replace_lines` - Remove or replace a range of lines within a file. Even
|
|
||||||
trickier for the model to use.
|
|
||||||
|
|
||||||
## Install
|
|
||||||
|
|
||||||
```shell
|
|
||||||
$ go install git.mlow.ca/mlow/lmcli@latest
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Invoke `lmcli` at least once:
|
```console
|
||||||
|
|
||||||
```shell
|
|
||||||
$ lmcli help
|
$ lmcli help
|
||||||
|
lmcli - Large Language Model CLI
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
lmcli <command> [flags]
|
||||||
|
lmcli [command]
|
||||||
|
|
||||||
|
Available Commands:
|
||||||
|
chat Open the chat interface
|
||||||
|
clone Clone conversations
|
||||||
|
completion Generate the autocompletion script for the specified shell
|
||||||
|
continue Continue a conversation from the last message
|
||||||
|
edit Edit the last user reply in a conversation
|
||||||
|
help Help about any command
|
||||||
|
list List conversations
|
||||||
|
new Start a new conversation
|
||||||
|
prompt Do a one-shot prompt
|
||||||
|
rename Rename a conversation
|
||||||
|
reply Reply to a conversation
|
||||||
|
retry Retry the last user reply in a conversation
|
||||||
|
rm Remove conversations
|
||||||
|
view View messages in a conversation
|
||||||
|
|
||||||
|
Flags:
|
||||||
|
-h, --help help for lmcli
|
||||||
|
|
||||||
|
Use "lmcli [command] --help" for more information about a command.
|
||||||
```
|
```
|
||||||
|
|
||||||
Edit `~/.config/lmcli/config.yaml` and set `openai.apiKey` to your API key.
|
### Examples
|
||||||
|
|
||||||
Refer back to the output of `lmcli help` for usage.
|
Start a new chat with the `coder` agent:
|
||||||
|
|
||||||
Enjoy!
|
```console
|
||||||
|
$ lmcli chat --agent coder
|
||||||
|
```
|
||||||
|
|
||||||
|
Start a new conversation, imperative style (no tui):
|
||||||
|
|
||||||
|
```console
|
||||||
|
$ lmcli new "Help me plan meals for the next week"
|
||||||
|
```
|
||||||
|
|
||||||
|
Send a one-shot prompt (no persistence):
|
||||||
|
|
||||||
|
```console
|
||||||
|
$ lmcli prompt "What is the answer to life, the universe, and everything?"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Contributions to `lmcli` are welcome! Feel free to open issues or submit pull requests on the project repository.
|
||||||
|
|
||||||
|
For a full list of planned features and improvements, check out the [TODO.md](TODO.md) file.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
To be determined
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
`lmcli` is a small hobby project. Special thanks to the Go community and the creators of the libraries used in this project.
|
||||||
|
37
TODO.md
Normal file
37
TODO.md
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# TODO
|
||||||
|
|
||||||
|
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
||||||
|
when calling anthropic?
|
||||||
|
- [x] `dir_tree` tool
|
||||||
|
- [x] Implement native Anthropic API tool calling
|
||||||
|
- [x] Agents - a name given to a system prompt + set of available tools +
|
||||||
|
potentially other relevent data (e.g. external service credentials, files for
|
||||||
|
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
||||||
|
code-helper`, `lmcli chat -a financier`).
|
||||||
|
- [ ] Specialized agents which have integrations beyond basic tool calling,
|
||||||
|
e.g. a coding agent which bakes in efficient code context management
|
||||||
|
(only the current state of relevant files get shown to the model in the
|
||||||
|
system prompt, rather than having them in the conversation messages)
|
||||||
|
- [ ] Agents may have some form of long term memory management (key-value?
|
||||||
|
natural lang?).
|
||||||
|
- [ ] Sandboxed python, js interpreters (implemented with containers)
|
||||||
|
- [ ] Support for arbitrary external script tools
|
||||||
|
- [ ] Search - RAG driven search of existing conversation "hey, remind me of
|
||||||
|
the conversation we had six months ago about X")
|
||||||
|
- [ ] Conversation categorization - model driven category creation and
|
||||||
|
conversation classification
|
||||||
|
- [ ] Image input
|
||||||
|
- [ ] Image output (sixel support?)
|
||||||
|
- [ ] Conversation exports to html/pdf/json
|
||||||
|
- [ ] Store message generation model
|
||||||
|
- [ ] Hidden CoT
|
||||||
|
- [ ] Token accounting
|
||||||
|
|
||||||
|
## UI
|
||||||
|
- [x] Prettify/normalize tool_call and tool_result outputs so they can be
|
||||||
|
shown/optionally hidden in `lmcli view` and `lmcli chat`
|
||||||
|
- [x] Conversation deletion in conversations view
|
||||||
|
- [ ] User confirmation before calling (some?) tools
|
||||||
|
- [ ] Message deletion, Ctrl+D to delete a message and attach its children to
|
||||||
|
its parent, Ctrl+Shift+D to delete a message and its descendents
|
||||||
|
- [ ] Show available key bindings and their action in any given view
|
41
go.mod
41
go.mod
@ -3,42 +3,41 @@ module git.mlow.ca/mlow/lmcli
|
|||||||
go 1.21
|
go 1.21
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/alecthomas/chroma/v2 v2.11.1
|
github.com/alecthomas/chroma/v2 v2.14.0
|
||||||
github.com/charmbracelet/bubbles v0.18.0
|
github.com/charmbracelet/bubbles v0.20.0
|
||||||
github.com/charmbracelet/bubbletea v0.25.0
|
github.com/charmbracelet/bubbletea v1.1.1
|
||||||
github.com/charmbracelet/lipgloss v0.10.0
|
github.com/charmbracelet/lipgloss v0.13.0
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
github.com/muesli/reflow v0.3.0
|
||||||
github.com/sashabaranov/go-openai v1.17.7
|
github.com/spf13/cobra v1.8.1
|
||||||
github.com/spf13/cobra v1.8.0
|
|
||||||
github.com/sqids/sqids-go v0.4.1
|
github.com/sqids/sqids-go v0.4.1
|
||||||
gorm.io/driver/sqlite v1.5.4
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/gorm v1.25.5
|
gorm.io/driver/sqlite v1.5.6
|
||||||
|
gorm.io/gorm v1.25.12
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/atotto/clipboard v0.1.4 // indirect
|
github.com/atotto/clipboard v0.1.4 // indirect
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
|
github.com/charmbracelet/x/ansi v0.3.1 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/charmbracelet/x/term v0.2.0 // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.11.4 // indirect
|
||||||
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/kr/pretty v0.3.1 // indirect
|
github.com/kr/pretty v0.3.1 // indirect
|
||||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
github.com/mattn/go-sqlite3 v1.14.23 // indirect
|
||||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/muesli/reflow v0.3.0 // indirect
|
|
||||||
github.com/muesli/termenv v0.15.2 // indirect
|
github.com/muesli/termenv v0.15.2 // indirect
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/spf13/pflag v1.0.5 // indirect
|
github.com/spf13/pflag v1.0.5 // indirect
|
||||||
golang.org/x/sync v0.1.0 // indirect
|
golang.org/x/sync v0.8.0 // indirect
|
||||||
golang.org/x/sys v0.14.0 // indirect
|
golang.org/x/sys v0.25.0 // indirect
|
||||||
golang.org/x/term v0.6.0 // indirect
|
golang.org/x/text v0.18.0 // indirect
|
||||||
golang.org/x/text v0.3.8 // indirect
|
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
|
||||||
)
|
)
|
||||||
|
91
go.sum
91
go.sum
@ -1,27 +1,31 @@
|
|||||||
github.com/alecthomas/assert/v2 v2.2.1 h1:XivOgYcduV98QCahG8T5XTezV5bylXe+lBxLG2K2ink=
|
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
|
||||||
github.com/alecthomas/assert/v2 v2.2.1/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ=
|
github.com/MakeNowJust/heredoc v1.0.0/go.mod h1:mG5amYoWBHf8vpLOuehzbGGw0EHxpZZ6lCpQ4fNJ8LE=
|
||||||
github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJWtdzJPs=
|
github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE=
|
||||||
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
|
||||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
github.com/alecthomas/chroma/v2 v2.14.0 h1:R3+wzpnUArGcQz7fCETQBzO5n9IMNi13iIs46aU4V9E=
|
||||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
github.com/alecthomas/chroma/v2 v2.14.0/go.mod h1:QolEbTfmUHIMVpBqxeDnNBj2uoeI4EbYP4i6n68SG4I=
|
||||||
|
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
|
||||||
|
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
|
github.com/charmbracelet/bubbles v0.20.0 h1:jSZu6qD8cRQ6k9OMfR1WlM+ruM8fkPWkHvQWD9LIutE=
|
||||||
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
|
github.com/charmbracelet/bubbles v0.20.0/go.mod h1:39slydyswPy+uVOHZ5x/GjwVAFkCsV8IIVy+4MhzwwU=
|
||||||
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
|
github.com/charmbracelet/bubbletea v1.1.1 h1:KJ2/DnmpfqFtDNVTvYZ6zpPFL9iRCRr0qqKOCvppbPY=
|
||||||
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
|
github.com/charmbracelet/bubbletea v1.1.1/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4=
|
||||||
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
|
github.com/charmbracelet/lipgloss v0.13.0 h1:4X3PPeoWEDCMvzDvGmTajSyYPcZM4+y8sCA/SsA3cjw=
|
||||||
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
|
github.com/charmbracelet/lipgloss v0.13.0/go.mod h1:nw4zy0SBX/F/eAO1cWdcvy6qnkDUxr8Lw7dvFrAIbbY=
|
||||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
|
github.com/charmbracelet/x/ansi v0.3.1 h1:CRO6lc/6HCx2/D6S/GZ87jDvRvk6GtPyFP+IljkNtqI=
|
||||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
|
github.com/charmbracelet/x/ansi v0.3.1/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
|
||||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0=
|
||||||
|
github.com/charmbracelet/x/term v0.2.0/go.mod h1:GVxgxAbjUrmpvIINHIQnJJKpMlHiZ4cktEQCN6GWyF0=
|
||||||
|
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
|
||||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
@ -36,17 +40,17 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
|||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
|
||||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||||
@ -61,31 +65,26 @@ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUc
|
|||||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
|
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
|
||||||
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
|
||||||
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
|
||||||
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
|
|
||||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
|
golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
|
||||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||||
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
|
|
||||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
|
||||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
|
||||||
gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
|
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||||
gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
|
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||||
|
142
pkg/agents/toolbox/dir_tree.go
Normal file
142
pkg/agents/toolbox/dir_tree.go
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
package toolbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
||||||
|
|
||||||
|
Use these results for your own reference in completing your task, they do not need to be shown to the user.
|
||||||
|
|
||||||
|
Example result:
|
||||||
|
{
|
||||||
|
"message": "success",
|
||||||
|
"result": ".
|
||||||
|
├── a_directory/
|
||||||
|
│ ├── file1.txt (100 bytes)
|
||||||
|
│ └── file2.txt (200 bytes)
|
||||||
|
├── a_file.txt (123 bytes)
|
||||||
|
└── another_file.txt (456 bytes)"
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
var DirTreeTool = api.ToolSpec{
|
||||||
|
Name: "dir_tree",
|
||||||
|
Description: TREE_DESCRIPTION,
|
||||||
|
Parameters: []api.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "relative_path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "If set, display the tree starting from this path relative to the current one.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "depth",
|
||||||
|
Type: "integer",
|
||||||
|
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
|
var relativeDir string
|
||||||
|
if tmp, ok := args["relative_path"]; ok {
|
||||||
|
relativeDir, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var depth int = 0 // Default value if not provided
|
||||||
|
if tmp, ok := args["depth"]; ok {
|
||||||
|
switch v := tmp.(type) {
|
||||||
|
case float64:
|
||||||
|
depth = int(v)
|
||||||
|
case string:
|
||||||
|
var err error
|
||||||
|
if depth, err = strconv.Atoi(v); err != nil {
|
||||||
|
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
|
||||||
|
}
|
||||||
|
depth = max(0, min(5, depth))
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := tree(relativeDir, depth)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func tree(path string, depth int) api.CallResult {
|
||||||
|
if path == "" {
|
||||||
|
path = "."
|
||||||
|
}
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return api.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
var treeOutput strings.Builder
|
||||||
|
treeOutput.WriteString(path + "\n")
|
||||||
|
err := buildTree(&treeOutput, path, "", depth)
|
||||||
|
if err != nil {
|
||||||
|
return api.CallResult{
|
||||||
|
Message: err.Error(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.CallResult{Result: treeOutput.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
||||||
|
files, err := os.ReadDir(path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, file := range files {
|
||||||
|
if strings.HasPrefix(file.Name(), ".") {
|
||||||
|
// Skip hidden files and directories
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
isLast := i == len(files)-1
|
||||||
|
var branch string
|
||||||
|
if isLast {
|
||||||
|
branch = "└── "
|
||||||
|
} else {
|
||||||
|
branch = "├── "
|
||||||
|
}
|
||||||
|
|
||||||
|
info, _ := file.Info()
|
||||||
|
size := info.Size()
|
||||||
|
sizeStr := fmt.Sprintf(" (%d bytes)", size)
|
||||||
|
|
||||||
|
output.WriteString(prefix + branch + file.Name())
|
||||||
|
if file.IsDir() {
|
||||||
|
output.WriteString("/\n")
|
||||||
|
if depth > 0 {
|
||||||
|
var nextPrefix string
|
||||||
|
if isLast {
|
||||||
|
nextPrefix = prefix + " "
|
||||||
|
} else {
|
||||||
|
nextPrefix = prefix + "│ "
|
||||||
|
}
|
||||||
|
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
output.WriteString(sizeStr + "\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,22 +1,22 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||||
|
|
||||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||||
|
|
||||||
var FileInsertLinesTool = model.Tool{
|
var FileInsertLinesTool = api.ToolSpec{
|
||||||
Name: "file_insert_lines",
|
Name: "file_insert_lines",
|
||||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -36,7 +36,7 @@ var FileInsertLinesTool = model.Tool{
|
|||||||
Required: true,
|
Required: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||||
@ -72,27 +72,27 @@ var FileInsertLinesTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func fileInsertLines(path string, position int, content string) model.CallResult {
|
func fileInsertLines(path string, position int, content string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the existing file's content
|
// Read the existing file's content
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
_, err = os.Create(path)
|
_, err = os.Create(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
}
|
}
|
||||||
data = []byte{}
|
data = []byte{}
|
||||||
}
|
}
|
||||||
|
|
||||||
if position < 1 {
|
if position < 1 {
|
||||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
return api.CallResult{Message: "start_line cannot be less than 1"}
|
||||||
}
|
}
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
lines := strings.Split(string(data), "\n")
|
||||||
@ -107,8 +107,8 @@ func fileInsertLines(path string, position int, content string) model.CallResult
|
|||||||
// Join the lines and write back to the file
|
// Join the lines and write back to the file
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: newContent}
|
return api.CallResult{Result: newContent}
|
||||||
}
|
}
|
178
pkg/agents/toolbox/modify_file.go
Normal file
178
pkg/agents/toolbox/modify_file.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package toolbox
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
var MODIFY_FILE_DESCRIPTION = []string{
|
||||||
|
"Modify a file. If the file does not exist, it will be created.",
|
||||||
|
"",
|
||||||
|
"Content can be either inserted, replaced, or removed through a combination of the start, stop, and content parameters.",
|
||||||
|
"Use the start and stop line numbers to limit the range of modification to the file.",
|
||||||
|
"If both `start` and `stop` are left unset (or set to 0), the entire file's contents will be updated.",
|
||||||
|
"If `start` is set to n and `stop` to n+1, content will be inserted at line n (the content that was at line n will be shifted below the newly inserted content).",
|
||||||
|
"If only `start` is set, content from the given line and onwards will be updated.",
|
||||||
|
"If only `stop` is set, content up to but not including the given line will be updated.",
|
||||||
|
"",
|
||||||
|
"Examples:",
|
||||||
|
"1. Append to a file:",
|
||||||
|
" {\"path\": \"example.txt\", \"start\": <last_line_number + 1>, \"content\": \"New content to append\"}",
|
||||||
|
"",
|
||||||
|
"2. Insert at a specific line:",
|
||||||
|
" {\"path\": \"example.txt\", \"start\": 5, \"stop\": 5, \"content\": \"New line inserted above the previous line 5\"}",
|
||||||
|
"",
|
||||||
|
"3. Replace a range of lines:",
|
||||||
|
" {\"path\": \"example.txt\", \"start\": 3, \"stop\": 7, \"content\": \"New content replacing lines 3-7\"}",
|
||||||
|
"",
|
||||||
|
"4. Remove a range of lines:",
|
||||||
|
" {\"path\": \"example.txt\", \"start\": 2, \"stop\": 5}",
|
||||||
|
"",
|
||||||
|
"5. Replace entire file contents:",
|
||||||
|
" {\"path\": \"example.txt\", \"content\": \"New file contents\"}",
|
||||||
|
"",
|
||||||
|
"6. Update from a specific line to the end of the file:",
|
||||||
|
" {\"path\": \"example.txt\", \"start\": 10, \"content\": \"New content from line 10 onwards\"}",
|
||||||
|
"",
|
||||||
|
"7. Update from the beginning of the file to a specific line:",
|
||||||
|
" {\"path\": \"example.txt\", \"stop\": 6, \"content\": \"New content for first 5 lines\"}",
|
||||||
|
"",
|
||||||
|
"Note: Always use specific line numbers based on the current file content. Avoid using arbitrarily large numbers for start or stop.",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ModifyFile = api.ToolSpec{
|
||||||
|
Name: "modify_file",
|
||||||
|
Description: strings.Join(MODIFY_FILE_DESCRIPTION, "\n"),
|
||||||
|
Parameters: []api.ToolParameter{
|
||||||
|
{
|
||||||
|
Name: "path",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||||
|
Required: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "start",
|
||||||
|
Type: "integer",
|
||||||
|
Description: `Start line of the range to modify (inclusive). If omitted, the beginning of the file is implied.`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "stop",
|
||||||
|
Type: "integer",
|
||||||
|
Description: `End line of the range to modify (inclusive). If omitted, the end of the file is implied.`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "content",
|
||||||
|
Type: "string",
|
||||||
|
Description: "Content to insert/replace at the range defined by `start` and `stop`. If omitted, the range is removed.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
|
tmp, ok := args["path"]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("path parameter to modify_file was not included.")
|
||||||
|
}
|
||||||
|
path, ok := tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
var start int
|
||||||
|
tmp, ok = args["start"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid start in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
start = int(tmp)
|
||||||
|
}
|
||||||
|
var stop int
|
||||||
|
tmp, ok = args["stop"]
|
||||||
|
if ok {
|
||||||
|
tmp, ok := tmp.(float64)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid stop in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
stop = int(tmp)
|
||||||
|
}
|
||||||
|
var content string
|
||||||
|
tmp, ok = args["content"]
|
||||||
|
if ok {
|
||||||
|
content, ok = tmp.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := fileModifyContents(path, start, stop, content)
|
||||||
|
ret, err := result.ToJson()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func fileModifyContents(path string, startLine int, stopLine int, content string) api.CallResult {
|
||||||
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
|
if !ok {
|
||||||
|
return api.CallResult{Message: reason}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the existing file's content
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
_, err = os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||||
|
}
|
||||||
|
data = []byte{}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(data), "\n")
|
||||||
|
contentLines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
|
||||||
|
|
||||||
|
// If both start and stop are unset, update the entire file
|
||||||
|
if startLine == 0 && stopLine == 0 {
|
||||||
|
lines = contentLines
|
||||||
|
} else {
|
||||||
|
if startLine < 1 {
|
||||||
|
startLine = 1
|
||||||
|
}
|
||||||
|
if stopLine == 0 || stopLine > len(lines) {
|
||||||
|
stopLine = len(lines)
|
||||||
|
}
|
||||||
|
|
||||||
|
before := lines[:startLine-1]
|
||||||
|
after := lines[stopLine:]
|
||||||
|
|
||||||
|
// Handle insertion case
|
||||||
|
if startLine == stopLine {
|
||||||
|
lines = append(before, append(contentLines, lines[startLine-1:]...)...)
|
||||||
|
} else {
|
||||||
|
// If content is omitted, remove the specified range
|
||||||
|
if content == "" {
|
||||||
|
lines = append(before, after...)
|
||||||
|
} else {
|
||||||
|
lines = append(before, append(contentLines, after...)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newContent := strings.Join(lines, "\n")
|
||||||
|
|
||||||
|
// Write back to the file
|
||||||
|
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.CallResult{Result: util.AddLineNumbers(newContent)}
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -6,8 +6,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||||
@ -25,17 +25,17 @@ Example result:
|
|||||||
For files, size represents the size of the file, in bytes.
|
For files, size represents the size of the file, in bytes.
|
||||||
For directories, size represents the number of entries in that directory.`
|
For directories, size represents the number of entries in that directory.`
|
||||||
|
|
||||||
var ReadDirTool = model.Tool{
|
var ReadDirTool = api.ToolSpec{
|
||||||
Name: "read_dir",
|
Name: "read_dir",
|
||||||
Description: READ_DIR_DESCRIPTION,
|
Description: READ_DIR_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "relative_dir",
|
Name: "relative_dir",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
Description: "If set, read the contents of a directory relative to the current one.",
|
Description: "If set, read the contents of a directory relative to the current one.",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
var relativeDir string
|
var relativeDir string
|
||||||
tmp, ok := args["relative_dir"]
|
tmp, ok := args["relative_dir"]
|
||||||
if ok {
|
if ok {
|
||||||
@ -53,18 +53,18 @@ var ReadDirTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func readDir(path string) model.CallResult {
|
func readDir(path string) api.CallResult {
|
||||||
if path == "" {
|
if path == "" {
|
||||||
path = "."
|
path = "."
|
||||||
}
|
}
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
|
|
||||||
files, err := os.ReadDir(path)
|
files, err := os.ReadDir(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{
|
return api.CallResult{
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -96,5 +96,5 @@ func readDir(path string) model.CallResult {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return model.CallResult{Result: dirContents}
|
return api.CallResult{Result: dirContents}
|
||||||
}
|
}
|
@ -1,15 +1,16 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
||||||
|
|
||||||
|
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
|
||||||
|
|
||||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||||
|
|
||||||
@ -19,10 +20,10 @@ Example result:
|
|||||||
"result": "1\tthe contents\n2\tof the file\n"
|
"result": "1\tthe contents\n2\tof the file\n"
|
||||||
}`
|
}`
|
||||||
|
|
||||||
var ReadFileTool = model.Tool{
|
var ReadFileTool = api.ToolSpec{
|
||||||
Name: "read_file",
|
Name: "read_file",
|
||||||
Description: READ_FILE_DESCRIPTION,
|
Description: READ_FILE_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -31,7 +32,7 @@ var ReadFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||||
@ -49,23 +50,16 @@ var ReadFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func readFile(path string) model.CallResult {
|
func readFile(path string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
data, err := os.ReadFile(path)
|
data, err := os.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
|
return api.CallResult{
|
||||||
lines := strings.Split(string(data), "\n")
|
Result: toolutil.AddLineNumbers(string(data)),
|
||||||
content := strings.Builder{}
|
|
||||||
for i, line := range lines {
|
|
||||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
|
||||||
}
|
|
||||||
|
|
||||||
return model.CallResult{
|
|
||||||
Result: content.String(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -65,3 +65,14 @@ func IsPathWithinCWD(path string) (bool, string) {
|
|||||||
}
|
}
|
||||||
return true, ""
|
return true, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddLineNumbers takes a string of content and returns a new string with line
|
||||||
|
// numbers prefixed
|
||||||
|
func AddLineNumbers(content string) string {
|
||||||
|
lines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
|
||||||
|
result := strings.Builder{}
|
||||||
|
for i, line := range lines {
|
||||||
|
result.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
@ -1,11 +1,11 @@
|
|||||||
package tools
|
package toolbox
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||||
@ -15,10 +15,10 @@ Example result:
|
|||||||
"message": "success"
|
"message": "success"
|
||||||
}`
|
}`
|
||||||
|
|
||||||
var WriteFileTool = model.Tool{
|
var WriteFileTool = api.ToolSpec{
|
||||||
Name: "write_file",
|
Name: "write_file",
|
||||||
Description: WRITE_FILE_DESCRIPTION,
|
Description: WRITE_FILE_DESCRIPTION,
|
||||||
Parameters: []model.ToolParameter{
|
Parameters: []api.ToolParameter{
|
||||||
{
|
{
|
||||||
Name: "path",
|
Name: "path",
|
||||||
Type: "string",
|
Type: "string",
|
||||||
@ -32,7 +32,7 @@ var WriteFileTool = model.Tool{
|
|||||||
Required: true,
|
Required: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||||
tmp, ok := args["path"]
|
tmp, ok := args["path"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||||
@ -58,14 +58,14 @@ var WriteFileTool = model.Tool{
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeFile(path string, content string) model.CallResult {
|
func writeFile(path string, content string) api.CallResult {
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||||
if !ok {
|
if !ok {
|
||||||
return model.CallResult{Message: reason}
|
return api.CallResult{Message: reason}
|
||||||
}
|
}
|
||||||
err := os.WriteFile(path, []byte(content), 0644)
|
err := os.WriteFile(path, []byte(content), 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||||
}
|
}
|
||||||
return model.CallResult{}
|
return api.CallResult{}
|
||||||
}
|
}
|
47
pkg/agents/tools.go
Normal file
47
pkg/agents/tools.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package agents
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
|
||||||
|
"dir_tree": toolbox.DirTreeTool,
|
||||||
|
"read_dir": toolbox.ReadDirTool,
|
||||||
|
"read_file": toolbox.ReadFileTool,
|
||||||
|
"modify_file": toolbox.ModifyFile,
|
||||||
|
"write_file": toolbox.WriteFileTool,
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
|
||||||
|
var toolResults []api.ToolResult
|
||||||
|
for _, call := range calls {
|
||||||
|
var tool *api.ToolSpec
|
||||||
|
for i := range available {
|
||||||
|
if available[i].Name == call.Name {
|
||||||
|
tool = &available[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tool == nil {
|
||||||
|
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
result, err := tool.Impl(tool, call.Parameters)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResult := api.ToolResult{
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
ToolName: call.Name,
|
||||||
|
Result: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, toolResult)
|
||||||
|
}
|
||||||
|
return toolResults, nil
|
||||||
|
}
|
126
pkg/api/api.go
Normal file
126
pkg/api/api.go
Normal file
@ -0,0 +1,126 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MessageRole string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessageRoleSystem MessageRole = "system"
|
||||||
|
MessageRoleUser MessageRole = "user"
|
||||||
|
MessageRoleAssistant MessageRole = "assistant"
|
||||||
|
MessageRoleToolCall MessageRole = "tool_call"
|
||||||
|
MessageRoleToolResult MessageRole = "tool_result"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Content string // TODO: support multi-part messages
|
||||||
|
Role MessageRole
|
||||||
|
ToolCalls []ToolCall
|
||||||
|
ToolResults []ToolResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolSpec struct {
|
||||||
|
Name string
|
||||||
|
Description string
|
||||||
|
Parameters []ToolParameter
|
||||||
|
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"` // "string", "integer", "boolean"
|
||||||
|
Required bool `json:"required"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
ID string `json:"id" yaml:"-"`
|
||||||
|
Name string `json:"name" yaml:"tool"`
|
||||||
|
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResult struct {
|
||||||
|
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||||
|
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||||
|
Result string `json:"result,omitempty" yaml:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessageWithAssistant(content string) *Message {
|
||||||
|
return &Message{
|
||||||
|
Role: MessageRoleAssistant,
|
||||||
|
Content: content,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMessageWithToolCalls(content string, toolCalls []ToolCall) *Message {
|
||||||
|
return &Message{
|
||||||
|
Role: MessageRoleToolCall,
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageRole) IsAssistant() bool {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleAssistant, MessageRoleToolCall:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageRole) IsUser() bool {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleUser, MessageRoleToolResult:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageRole) IsSystem() bool {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleSystem:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||||
|
func (m MessageRole) FriendlyRole() string {
|
||||||
|
switch m {
|
||||||
|
case MessageRoleUser:
|
||||||
|
return "You"
|
||||||
|
case MessageRoleSystem:
|
||||||
|
return "System"
|
||||||
|
case MessageRoleAssistant:
|
||||||
|
return "Assistant"
|
||||||
|
case MessageRoleToolCall:
|
||||||
|
return "Tool Call"
|
||||||
|
case MessageRoleToolResult:
|
||||||
|
return "Tool Result"
|
||||||
|
default:
|
||||||
|
return string(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: remove this
|
||||||
|
type CallResult struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Result any `json:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r CallResult) ToJson() (string, error) {
|
||||||
|
if r.Message == "" {
|
||||||
|
// When message not supplied, assume success
|
||||||
|
r.Message = "success"
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
@ -3,8 +3,10 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/tui"
|
"git.mlow.ca/mlow/lmcli/pkg/tui"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -14,12 +16,34 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "Open the chat interface",
|
Short: "Open the chat interface",
|
||||||
Long: `Open the chat interface, optionally on a given conversation.`,
|
Long: `Open the chat interface, optionally on a given conversation.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
// TODO: implement jump-to-conversation logic
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
shortname := ""
|
if err != nil {
|
||||||
if len(args) == 1 {
|
return err
|
||||||
shortname = args[0]
|
|
||||||
}
|
}
|
||||||
err := tui.Launch(ctx, shortname)
|
|
||||||
|
var opts []tui.LaunchOption
|
||||||
|
|
||||||
|
list, err := cmd.Flags().GetBool("list")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !list && len(args) == 1 {
|
||||||
|
shortname := args[0]
|
||||||
|
if shortname != ""{
|
||||||
|
conv, err := cmdutil.LookupConversationE(ctx, shortname)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
opts = append(opts, tui.WithInitialConversation(conv))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if list {
|
||||||
|
opts = append(opts, tui.WithInitialView(shared.ViewConversations))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tui.Launch(ctx, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
@ -30,8 +54,13 @@ func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -l, --list
|
||||||
|
cmd.Flags().BoolP("list", "l", false, "View/manage conversations")
|
||||||
|
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -28,36 +27,12 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
messagesToCopy, err := ctx.Store.Messages(toClone)
|
clone, messageCnt, err := ctx.Conversations.CloneConversation(*toClone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
return fmt.Errorf("Failed to clone conversation: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
clone := &model.Conversation{
|
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
|
||||||
Title: toClone.Title + " - Clone",
|
|
||||||
}
|
|
||||||
if err := ctx.Store.SaveConversation(clone); err != nil {
|
|
||||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var errors []error
|
|
||||||
messageCnt := 0
|
|
||||||
for _, message := range messagesToCopy {
|
|
||||||
newMessage := message
|
|
||||||
newMessage.ConversationID = clone.ID
|
|
||||||
newMessage.ID = 0
|
|
||||||
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
|
||||||
errors = append(errors, err)
|
|
||||||
} else {
|
|
||||||
messageCnt++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(errors) > 0 {
|
|
||||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
@ -65,7 +40,7 @@ func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return cmd
|
return cmd
|
||||||
|
110
pkg/cmd/cmd.go
110
pkg/cmd/cmd.go
@ -1,6 +1,8 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
@ -8,10 +10,6 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
systemPromptFile string
|
|
||||||
)
|
|
||||||
|
|
||||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||||
var root = &cobra.Command{
|
var root = &cobra.Command{
|
||||||
Use: "lmcli <command> [flags]",
|
Use: "lmcli <command> [flags]",
|
||||||
@ -23,58 +21,72 @@ func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
chatCmd := ChatCmd(ctx)
|
|
||||||
continueCmd := ContinueCmd(ctx)
|
|
||||||
cloneCmd := CloneCmd(ctx)
|
|
||||||
editCmd := EditCmd(ctx)
|
|
||||||
listCmd := ListCmd(ctx)
|
|
||||||
newCmd := NewCmd(ctx)
|
|
||||||
promptCmd := PromptCmd(ctx)
|
|
||||||
renameCmd := RenameCmd(ctx)
|
|
||||||
replyCmd := ReplyCmd(ctx)
|
|
||||||
retryCmd := RetryCmd(ctx)
|
|
||||||
rmCmd := RemoveCmd(ctx)
|
|
||||||
viewCmd := ViewCmd(ctx)
|
|
||||||
|
|
||||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
|
||||||
for _, cmd := range inputCmds {
|
|
||||||
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
|
||||||
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
|
||||||
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
|
||||||
})
|
|
||||||
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
|
||||||
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
|
||||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
|
||||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
|
||||||
}
|
|
||||||
|
|
||||||
root.AddCommand(
|
root.AddCommand(
|
||||||
chatCmd,
|
ChatCmd(ctx),
|
||||||
cloneCmd,
|
ContinueCmd(ctx),
|
||||||
continueCmd,
|
CloneCmd(ctx),
|
||||||
editCmd,
|
EditCmd(ctx),
|
||||||
listCmd,
|
ListCmd(ctx),
|
||||||
newCmd,
|
NewCmd(ctx),
|
||||||
promptCmd,
|
PromptCmd(ctx),
|
||||||
renameCmd,
|
RenameCmd(ctx),
|
||||||
replyCmd,
|
ReplyCmd(ctx),
|
||||||
retryCmd,
|
RetryCmd(ctx),
|
||||||
rmCmd,
|
RemoveCmd(ctx),
|
||||||
viewCmd,
|
ViewCmd(ctx),
|
||||||
)
|
)
|
||||||
|
|
||||||
return root
|
return root
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSystemPrompt(ctx *lmcli.Context) string {
|
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
|
||||||
if systemPromptFile != "" {
|
f := cmd.Flags()
|
||||||
content, err := util.ReadFileContents(systemPromptFile)
|
|
||||||
if err != nil {
|
// -m, --model
|
||||||
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
f.StringVarP(
|
||||||
}
|
ctx.Config.Defaults.Model, "model", "m",
|
||||||
return content
|
*ctx.Config.Defaults.Model, "Which model to generate a response with",
|
||||||
|
)
|
||||||
|
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
||||||
|
})
|
||||||
|
|
||||||
|
// -a, --agent
|
||||||
|
f.StringVarP(&ctx.Config.Defaults.Agent, "agent", "a", ctx.Config.Defaults.Agent, "Which agent to interact with")
|
||||||
|
cmd.RegisterFlagCompletionFunc("agent", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||||
|
return ctx.GetAgents(), cobra.ShellCompDirectiveDefault
|
||||||
|
})
|
||||||
|
|
||||||
|
// --max-length
|
||||||
|
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||||
|
// --temperature
|
||||||
|
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
|
||||||
|
|
||||||
|
// --system-prompt
|
||||||
|
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||||
|
// --system-prompt-file
|
||||||
|
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
|
||||||
|
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
|
||||||
|
f := cmd.Flags()
|
||||||
|
model, err := f.GetString("model")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error parsing --model: %w", err)
|
||||||
}
|
}
|
||||||
return *ctx.Config.Defaults.SystemPrompt
|
if model != "" && !slices.Contains(ctx.GetModels(), model) {
|
||||||
|
return fmt.Errorf("Unknown model: %s", model)
|
||||||
|
}
|
||||||
|
|
||||||
|
agent, err := f.GetString("agent")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error parsing --agent: %w", err)
|
||||||
|
}
|
||||||
|
if agent != "" && agent != "none" && !slices.Contains(ctx.GetAgents(), agent) {
|
||||||
|
return fmt.Errorf("Unknown agent: %s", agent)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||||
|
@ -4,9 +4,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,10 +23,15 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
shortName := args[0]
|
||||||
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||||
}
|
}
|
||||||
@ -36,7 +41,7 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
lastMessage := &messages[len(messages)-1]
|
lastMessage := &messages[len(messages)-1]
|
||||||
if lastMessage.Role != model.MessageRoleAssistant {
|
if lastMessage.Role != api.MessageRoleAssistant {
|
||||||
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,16 +49,16 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
fmt.Print(lastMessage.Content)
|
fmt.Print(lastMessage.Content)
|
||||||
|
|
||||||
// Submit the LLM request, allowing it to continue the last message
|
// Submit the LLM request, allowing it to continue the last message
|
||||||
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append the new response to the original message
|
// Append the new response to the original message
|
||||||
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
|
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
|
||||||
|
|
||||||
// Update the original message
|
// Update the original message
|
||||||
err = ctx.Store.UpdateMessage(lastMessage)
|
err = ctx.Conversations.UpdateMessage(lastMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("could not update the last message: %v", err)
|
return fmt.Errorf("could not update the last message: %v", err)
|
||||||
}
|
}
|
||||||
@ -65,8 +70,9 @@ func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,11 +22,11 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
offset, _ := cmd.Flags().GetInt("offset")
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
@ -39,21 +39,7 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
desiredIdx := len(messages) - 1 - offset
|
desiredIdx := len(messages) - 1 - offset
|
||||||
|
toEdit := messages[desiredIdx]
|
||||||
// walk backwards through the conversation deleting messages until and
|
|
||||||
// including the last user message
|
|
||||||
toRemove := []model.Message{}
|
|
||||||
var toEdit *model.Message
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if i == desiredIdx {
|
|
||||||
toEdit = &messages[i]
|
|
||||||
}
|
|
||||||
toRemove = append(toRemove, messages[i])
|
|
||||||
messages = messages[:i]
|
|
||||||
if toEdit != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||||
switch newContents {
|
switch newContents {
|
||||||
@ -63,38 +49,51 @@ func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toEdit.Content = newContents
|
||||||
|
|
||||||
role, _ := cmd.Flags().GetString("role")
|
role, _ := cmd.Flags().GetString("role")
|
||||||
if role == "" {
|
if role != "" {
|
||||||
role = string(toEdit.Role)
|
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) {
|
||||||
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, message := range toRemove {
|
|
||||||
err = ctx.Store.DeleteMessage(&message)
|
|
||||||
if err != nil {
|
|
||||||
lmcli.Warn("Could not delete message: %v\n", err)
|
|
||||||
}
|
}
|
||||||
|
toEdit.Role = api.MessageRole(role)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
// Update the message in-place
|
||||||
ConversationID: conversation.ID,
|
inplace, _ := cmd.Flags().GetBool("in-place")
|
||||||
Role: model.MessageRole(role),
|
if inplace {
|
||||||
Content: newContents,
|
return ctx.Conversations.UpdateMessage(&toEdit)
|
||||||
})
|
}
|
||||||
return nil
|
|
||||||
|
// Otherwise, create a branch for the edited message
|
||||||
|
message, _, err := ctx.Conversations.CloneBranch(toEdit)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if desiredIdx > 0 {
|
||||||
|
// update selected reply
|
||||||
|
messages[desiredIdx-1].SelectedReply = message
|
||||||
|
err = ctx.Conversations.UpdateMessage(&messages[desiredIdx-1])
|
||||||
|
} else {
|
||||||
|
// update selected root
|
||||||
|
c.SelectedRoot = message
|
||||||
|
err = ctx.Conversations.UpdateConversation(c)
|
||||||
|
}
|
||||||
|
return err
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
|
||||||
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||||
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
|
cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)")
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -2,7 +2,6 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
@ -21,9 +20,9 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "List conversations",
|
Short: "List conversations",
|
||||||
Long: `List conversations in order of recent activity`,
|
Long: `List conversations in order of recent activity`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
conversations, err := ctx.Store.Conversations()
|
list, err := ctx.Conversations.LoadConversationList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
return fmt.Errorf("Could not load conversations: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Category struct {
|
type Category struct {
|
||||||
@ -58,17 +57,12 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
all, _ := cmd.Flags().GetBool("all")
|
all, _ := cmd.Flags().GetBool("all")
|
||||||
|
|
||||||
for _, conversation := range conversations {
|
for _, item := range list.Items {
|
||||||
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
age := now.Sub(item.LastMessageAt)
|
||||||
if lastMessage == nil || err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
|
||||||
|
|
||||||
var category string
|
var category string
|
||||||
for _, c := range categories {
|
for _, c := range categories {
|
||||||
if messageAge < c.cutoff {
|
if age < c.cutoff {
|
||||||
category = c.name
|
category = c.name
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -76,14 +70,14 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
|
|
||||||
formatted := fmt.Sprintf(
|
formatted := fmt.Sprintf(
|
||||||
"%s - %s - %s",
|
"%s - %s - %s",
|
||||||
conversation.ShortName.String,
|
item.ShortName,
|
||||||
util.HumanTimeElapsedSince(messageAge),
|
util.HumanTimeElapsedSince(age),
|
||||||
conversation.Title,
|
item.Title,
|
||||||
)
|
)
|
||||||
|
|
||||||
categorized[category] = append(
|
categorized[category] = append(
|
||||||
categorized[category],
|
categorized[category],
|
||||||
ConversationLine{messageAge, formatted},
|
ConversationLine{age, formatted},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,14 +90,10 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
|
||||||
return int(a.timeSinceReply - b.timeSinceReply)
|
|
||||||
})
|
|
||||||
|
|
||||||
fmt.Printf("%s:\n", category.name)
|
fmt.Printf("%s:\n", category.name)
|
||||||
for _, conv := range conversationLines {
|
for _, conv := range conversationLines {
|
||||||
if conversationsPrinted >= count && !all {
|
if conversationsPrinted >= count && !all {
|
||||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
fmt.Printf("%d remaining conversation(s), use --all to view.\n", list.Total-conversationsPrinted)
|
||||||
break outer
|
break outer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,8 +105,8 @@ func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Flags().Bool("all", false, "Show all conversations")
|
cmd.Flags().BoolP("all", "a", false, "Show all conversations")
|
||||||
cmd.Flags().Int("count", LS_COUNT, "How many conversations to show")
|
cmd.Flags().IntP("count", "c", LS_COUNT, "How many conversations to show")
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,10 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,46 +16,42 @@ func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "Start a new conversation",
|
Short: "Start a new conversation",
|
||||||
Long: `Start a new conversation with the Large Language Model.`,
|
Long: `Start a new conversation with the Large Language Model.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
if messageContents == "" {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
|
||||||
|
if input == "" {
|
||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
conversation := &model.Conversation{}
|
messages := []conversation.Message{{
|
||||||
err := ctx.Store.SaveConversation(conversation)
|
Role: api.MessageRoleUser,
|
||||||
|
Content: input,
|
||||||
|
}}
|
||||||
|
|
||||||
|
conversation, messages, err := ctx.Conversations.StartConversation(messages...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not save new conversation: %v", err)
|
return fmt.Errorf("Could not start a new conversation: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []model.Message{
|
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
|
||||||
{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: model.MessageRoleSystem,
|
|
||||||
Content: getSystemPrompt(ctx),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
ConversationID: conversation.ID,
|
|
||||||
Role: model.MessageRoleUser,
|
|
||||||
Content: messageContents,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
title, err := cmdutil.GenerateTitle(ctx, messages)
|
||||||
|
|
||||||
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
|
err = ctx.Conversations.UpdateConversation(conversation)
|
||||||
err = ctx.Store.SaveConversation(conversation)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
lmcli.Warn("Could not save conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,10 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,28 +16,29 @@ func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
Short: "Do a one-shot prompt",
|
Short: "Do a one-shot prompt",
|
||||||
Long: `Prompt the Large Language Model and get a response.`,
|
Long: `Prompt the Large Language Model and get a response.`,
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
if message == "" {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
|
||||||
|
if input == "" {
|
||||||
return fmt.Errorf("No message was provided.")
|
return fmt.Errorf("No message was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
messages := []model.Message{
|
messages := []conversation.Message{{
|
||||||
{
|
Role: api.MessageRoleUser,
|
||||||
Role: model.MessageRoleSystem,
|
Content: input,
|
||||||
Content: getSystemPrompt(ctx),
|
}}
|
||||||
},
|
|
||||||
{
|
|
||||||
Role: model.MessageRoleUser,
|
|
||||||
Content: message,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
_, err = cmdutil.Prompt(ctx, messages, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -5,8 +5,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,14 +23,14 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
var toRemove []*model.Conversation
|
var toRemove []*conversation.Conversation
|
||||||
for _, shortName := range args {
|
for _, shortName := range args {
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
toRemove = append(toRemove, conversation)
|
toRemove = append(toRemove, conversation)
|
||||||
}
|
}
|
||||||
var errors []error
|
var errors []error
|
||||||
for _, c := range toRemove {
|
for _, c := range toRemove {
|
||||||
err := ctx.Store.DeleteConversation(c)
|
err := ctx.Conversations.DeleteConversation(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||||
}
|
}
|
||||||
@ -44,7 +44,7 @@ func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||||
var completions []string
|
var completions []string
|
||||||
outer:
|
outer:
|
||||||
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
for _, completion := range ctx.Conversations.ConversationShortNameCompletions(toComplete) {
|
||||||
parts := strings.Split(completion, "\t")
|
parts := strings.Split(completion, "\t")
|
||||||
for _, arg := range args {
|
for _, arg := range args {
|
||||||
if parts[0] == arg {
|
if parts[0] == arg {
|
||||||
|
@ -24,12 +24,17 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
var title string
|
||||||
|
|
||||||
generate, _ := cmd.Flags().GetBool("generate")
|
generate, _ := cmd.Flags().GetBool("generate")
|
||||||
var title string
|
|
||||||
if generate {
|
if generate {
|
||||||
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
||||||
|
}
|
||||||
|
title, err = cmdutil.GenerateTitle(ctx, messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||||
}
|
}
|
||||||
@ -41,9 +46,9 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
conversation.Title = title
|
conversation.Title = title
|
||||||
err = ctx.Store.SaveConversation(conversation)
|
err = ctx.Conversations.UpdateConversation(conversation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
lmcli.Warn("Could not update conversation title: %v\n", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@ -52,7 +57,7 @@ func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,9 +3,10 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,17 +23,21 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||||
if reply == "" {
|
if reply == "" {
|
||||||
return fmt.Errorf("No reply was provided.")
|
return fmt.Errorf("No reply was provided.")
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
cmdutil.HandleConversationReply(ctx, c, true, conversation.Message{
|
||||||
ConversationID: conversation.ID,
|
Role: api.MessageRoleUser,
|
||||||
Role: model.MessageRoleUser,
|
|
||||||
Content: reply,
|
Content: reply,
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@ -42,8 +47,10 @@ func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -3,9 +3,9 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "retry <conversation>",
|
Use: "retry <conversation>",
|
||||||
Short: "Retry the last user reply in a conversation",
|
Short: "Retry the last user reply in a conversation",
|
||||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
Long: `Prompt the conversation from the last user response.`,
|
||||||
Args: func(cmd *cobra.Command, args []string) error {
|
Args: func(cmd *cobra.Command, args []string) error {
|
||||||
argCount := 1
|
argCount := 1
|
||||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||||
@ -22,28 +22,44 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
err := validateGenerationFlags(ctx, cmd)
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// walk backwards through the conversation and delete messages, break
|
shortName := args[0]
|
||||||
// when we find the latest user response
|
c := cmdutil.LookupConversation(ctx, shortName)
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if messages[i].Role == model.MessageRoleUser {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
err = ctx.Store.DeleteMessage(&messages[i])
|
// Load the complete thread from the root message
|
||||||
if err != nil {
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("Could not retrieve messages for conversation: %s", c.Title)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.HandleConversationReply(ctx, conversation, true)
|
offset, _ := cmd.Flags().GetInt("offset")
|
||||||
|
if offset < 0 {
|
||||||
|
offset = -offset
|
||||||
|
}
|
||||||
|
|
||||||
|
if offset > len(messages)-1 {
|
||||||
|
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
retryFromIdx := len(messages) - 1 - offset
|
||||||
|
|
||||||
|
// decrease retryFromIdx until we hit a user message
|
||||||
|
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||||
|
retryFromIdx--
|
||||||
|
}
|
||||||
|
|
||||||
|
if messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||||
|
return fmt.Errorf("No user messages to retry")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
|
||||||
|
|
||||||
|
// Start a new branch at the last user message
|
||||||
|
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||||
@ -51,8 +67,12 @@ func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cmd.Flags().Int("offset", 0, "Offset from the last message to retry from.")
|
||||||
|
|
||||||
|
applyGenerationFlags(ctx, cmd)
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
@ -2,42 +2,59 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
)
|
)
|
||||||
|
|
||||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
// Prompt prompts the configured the configured model and streams the response
|
||||||
// the response to stdout. Returns all model reply messages.
|
// to stdout. Returns all model reply messages.
|
||||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
func Prompt(ctx *lmcli.Context, messages []conversation.Message, callback func(conversation.Message)) (*api.Message, error) {
|
||||||
content := make(chan string) // receives the reponse from LLM
|
m, _, p, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
defer close(content)
|
|
||||||
|
|
||||||
// render all content received over the channel
|
|
||||||
go ShowDelayedContent(content)
|
|
||||||
|
|
||||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := model.RequestParameters{
|
params := provider.RequestParameters{
|
||||||
Model: *ctx.Config.Defaults.Model,
|
Model: m,
|
||||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||||
Temperature: *ctx.Config.Defaults.Temperature,
|
Temperature: *ctx.Config.Defaults.Temperature,
|
||||||
ToolBag: ctx.EnabledTools,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := completionProvider.CreateChatCompletionStream(
|
system := ctx.DefaultSystemPrompt()
|
||||||
context.Background(), requestParams, messages, callback, content,
|
|
||||||
|
agent := ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||||
|
if agent != nil {
|
||||||
|
if agent.SystemPrompt != "" {
|
||||||
|
system = agent.SystemPrompt
|
||||||
|
}
|
||||||
|
params.Toolbox = agent.Toolbox
|
||||||
|
}
|
||||||
|
|
||||||
|
if system != "" {
|
||||||
|
messages = conversation.ApplySystemPrompt(messages, system, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := make(chan provider.Chunk)
|
||||||
|
defer close(content)
|
||||||
|
|
||||||
|
// render the content received over the channel
|
||||||
|
go ShowDelayedContent(content)
|
||||||
|
|
||||||
|
reply, err := p.CreateChatCompletionStream(
|
||||||
|
context.Background(), params, conversation.MessagesToAPI(messages), content,
|
||||||
)
|
)
|
||||||
if response != "" {
|
|
||||||
|
if reply.Content != "" {
|
||||||
// there was some content, so break to a new line after it
|
// there was some content, so break to a new line after it
|
||||||
fmt.Println()
|
fmt.Println()
|
||||||
|
|
||||||
@ -46,85 +63,99 @@ func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callba
|
|||||||
err = nil
|
err = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return response, nil
|
return reply, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupConversation either returns the conversation found by the
|
// lookupConversation either returns the conversation found by the
|
||||||
// short name or exits the program
|
// short name or exits the program
|
||||||
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
func LookupConversation(ctx *lmcli.Context, shortName string) *conversation.Conversation {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||||
}
|
}
|
||||||
if c.ID == 0 {
|
if c.ID == 0 {
|
||||||
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
lmcli.Fatal("Conversation not found: %s\n", shortName)
|
||||||
}
|
}
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
func LookupConversationE(ctx *lmcli.Context, shortName string) (*conversation.Conversation, error) {
|
||||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
c, err := ctx.Conversations.FindConversationByShortName(shortName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||||
}
|
}
|
||||||
if c.ID == 0 {
|
if c.ID == 0 {
|
||||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HandleConversationReply(ctx *lmcli.Context, c *conversation.Conversation, persist bool, toSend ...conversation.Message) {
|
||||||
|
messages, err := ctx.Conversations.PathToLeaf(c.SelectedRoot)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
|
}
|
||||||
|
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
||||||
|
}
|
||||||
|
|
||||||
// handleConversationReply handles sending messages to an existing
|
// handleConversationReply handles sending messages to an existing
|
||||||
// conversation, optionally persisting both the sent replies and responses.
|
// conversation, optionally persisting both the sent replies and responses.
|
||||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
func HandleReply(ctx *lmcli.Context, to *conversation.Message, persist bool, messages ...conversation.Message) {
|
||||||
existing, err := ctx.Store.Messages(c)
|
if to == nil {
|
||||||
if err != nil {
|
lmcli.Fatal("Can't prompt from an empty message.")
|
||||||
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if persist {
|
existing, err := ctx.Conversations.PathToRoot(to)
|
||||||
for _, message := range toSend {
|
if err != nil {
|
||||||
err = ctx.Store.SaveMessage(&message)
|
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||||
if err != nil {
|
}
|
||||||
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
|
||||||
}
|
RenderConversation(ctx, append(existing, messages...), true)
|
||||||
|
|
||||||
|
var savedReplies []conversation.Message
|
||||||
|
if persist && len(messages) > 0 {
|
||||||
|
savedReplies, err = ctx.Conversations.Reply(to, messages...)
|
||||||
|
if err != nil {
|
||||||
|
lmcli.Warn("Could not save messages: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allMessages := append(existing, toSend...)
|
|
||||||
|
|
||||||
RenderConversation(ctx, allMessages, true)
|
|
||||||
|
|
||||||
// render a message header with no contents
|
// render a message header with no contents
|
||||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
RenderMessage(ctx, (&conversation.Message{Role: api.MessageRoleAssistant}))
|
||||||
|
|
||||||
replyCallback := func(reply model.Message) {
|
var lastSavedMessage *conversation.Message
|
||||||
|
lastSavedMessage = to
|
||||||
|
if len(savedReplies) > 0 {
|
||||||
|
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
replyCallback := func(reply conversation.Message) {
|
||||||
if !persist {
|
if !persist {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
savedReplies, err = ctx.Conversations.Reply(lastSavedMessage, reply)
|
||||||
reply.ConversationID = c.ID
|
|
||||||
err = ctx.Store.SaveMessage(&reply)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Warn("Could not save reply: %v\n", err)
|
lmcli.Warn("Could not save reply: %v\n", err)
|
||||||
}
|
}
|
||||||
|
lastSavedMessage = &savedReplies[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
func FormatForExternalPrompt(messages []conversation.Message, system bool) string {
|
||||||
sb := strings.Builder{}
|
sb := strings.Builder{}
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
if message.Content == "" {
|
if message.Content == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch message.Role {
|
switch message.Role {
|
||||||
case model.MessageRoleAssistant, model.MessageRoleToolCall:
|
case api.MessageRoleAssistant, api.MessageRoleToolCall:
|
||||||
sb.WriteString("Assistant:\n\n")
|
sb.WriteString("Assistant:\n\n")
|
||||||
case model.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
sb.WriteString("User:\n\n")
|
sb.WriteString("User:\n\n")
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
@ -134,60 +165,76 @@ func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
func GenerateTitle(ctx *lmcli.Context, messages []conversation.Message) (string, error) {
|
||||||
messages, err := ctx.Store.Messages(c)
|
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
|
||||||
|
|
||||||
Example conversation:
|
Example conversation:
|
||||||
|
|
||||||
"""
|
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
|
||||||
User:
|
|
||||||
|
|
||||||
Hello!
|
|
||||||
|
|
||||||
Assistant:
|
|
||||||
|
|
||||||
Hello! How may I assist you?
|
|
||||||
"""
|
|
||||||
|
|
||||||
Example response:
|
Example response:
|
||||||
|
|
||||||
"""
|
{"title": "Help with math homework"}
|
||||||
Title: A brief introduction
|
|
||||||
"""
|
|
||||||
`
|
`
|
||||||
conversation := FormatForExternalPrompt(messages, false)
|
type msg struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
}
|
||||||
|
|
||||||
generateRequest := []model.Message{
|
var msgs []msg
|
||||||
|
for _, m := range messages {
|
||||||
|
switch m.Role {
|
||||||
|
case api.MessageRoleAssistant, api.MessageRoleUser:
|
||||||
|
msgs = append(msgs, msg{string(m.Role), m.Content})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize the conversation to JSON
|
||||||
|
jsonBytes, err := json.Marshal(msgs)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
generateRequest := []conversation.Message{
|
||||||
{
|
{
|
||||||
Role: model.MessageRoleUser,
|
Role: api.MessageRoleSystem,
|
||||||
Content: fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n%s", conversation, prompt),
|
Content: systemPrompt,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: api.MessageRoleUser,
|
||||||
|
Content: string(jsonBytes),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
m, _, p, err := ctx.GetModelProvider(
|
||||||
|
*ctx.Config.Conversations.TitleGenerationModel, "",
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
requestParams := model.RequestParameters{
|
requestParams := provider.RequestParameters{
|
||||||
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
Model: m,
|
||||||
MaxTokens: 25,
|
MaxTokens: 25,
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
response, err := p.CreateChatCompletion(
|
||||||
|
context.Background(), requestParams, conversation.MessagesToAPI(generateRequest),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
response = strings.TrimPrefix(response, "Title: ")
|
// Parse the JSON response
|
||||||
response = strings.Trim(response, "\"")
|
var jsonResponse struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
return response, nil
|
return jsonResponse.Title, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ShowWaitAnimation prints an animated ellipses to stdout until something is
|
// ShowWaitAnimation prints an animated ellipses to stdout until something is
|
||||||
@ -227,7 +274,7 @@ func ShowWaitAnimation(signal chan any) {
|
|||||||
// chunked) content is received on the channel, the waiting animation is
|
// chunked) content is received on the channel, the waiting animation is
|
||||||
// replaced by the content.
|
// replaced by the content.
|
||||||
// Blocks until the channel is closed.
|
// Blocks until the channel is closed.
|
||||||
func ShowDelayedContent(content <-chan string) {
|
func ShowDelayedContent(content <-chan provider.Chunk) {
|
||||||
waitSignal := make(chan any)
|
waitSignal := make(chan any)
|
||||||
go ShowWaitAnimation(waitSignal)
|
go ShowWaitAnimation(waitSignal)
|
||||||
|
|
||||||
@ -240,14 +287,14 @@ func ShowDelayedContent(content <-chan string) {
|
|||||||
<-waitSignal
|
<-waitSignal
|
||||||
firstChunk = false
|
firstChunk = false
|
||||||
}
|
}
|
||||||
fmt.Print(chunk)
|
fmt.Print(chunk.Content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenderConversation renders the given messages to TTY, with optional space
|
// RenderConversation renders the given messages to TTY, with optional space
|
||||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||||
// are printed immediately after the final message (1 if false, 2 if true)
|
// are printed immediately after the final message (1 if false, 2 if true)
|
||||||
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
func RenderConversation(ctx *lmcli.Context, messages []conversation.Message, spaceForResponse bool) {
|
||||||
l := len(messages)
|
l := len(messages)
|
||||||
for i, message := range messages {
|
for i, message := range messages {
|
||||||
RenderMessage(ctx, &message)
|
RenderMessage(ctx, &message)
|
||||||
@ -258,7 +305,7 @@ func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForRe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
func RenderMessage(ctx *lmcli.Context, m *conversation.Message) {
|
||||||
var messageAge string
|
var messageAge string
|
||||||
if m.CreatedAt.IsZero() {
|
if m.CreatedAt.IsZero() {
|
||||||
messageAge = "now"
|
messageAge = "now"
|
||||||
@ -270,11 +317,11 @@ func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
|||||||
headerStyle := lipgloss.NewStyle().Bold(true)
|
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
|
||||||
switch m.Role {
|
switch m.Role {
|
||||||
case model.MessageRoleSystem:
|
case api.MessageRoleSystem:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||||
case model.MessageRoleUser:
|
case api.MessageRoleUser:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||||
case model.MessageRoleAssistant:
|
case api.MessageRoleAssistant:
|
||||||
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,13 +20,13 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
shortName := args[0]
|
shortName := args[0]
|
||||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||||
|
|
||||||
messages, err := ctx.Store.Messages(conversation)
|
messages, err := ctx.Conversations.PathToLeaf(conversation.SelectedRoot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmdutil.RenderConversation(ctx, messages, false)
|
cmdutil.RenderConversation(ctx, messages, false)
|
||||||
@ -37,7 +37,7 @@ func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
|||||||
if len(args) != 0 {
|
if len(args) != 0 {
|
||||||
return nil, compMode
|
return nil, compMode
|
||||||
}
|
}
|
||||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
return ctx.Conversations.ConversationShortNameCompletions(toComplete), compMode
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
99
pkg/conversation/conversation.go
Normal file
99
pkg/conversation/conversation.go
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Conversation struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
ShortName sql.NullString
|
||||||
|
Title string
|
||||||
|
SelectedRootID *uint
|
||||||
|
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||||
|
RootMessages []Message `gorm:"-:all"`
|
||||||
|
LastMessageAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageMeta struct {
|
||||||
|
GenerationProvider *string `json:"generation_provider,omitempty"`
|
||||||
|
GenerationModel *string `json:"generation_model,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint `gorm:"primaryKey"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
Metadata MessageMeta
|
||||||
|
|
||||||
|
ConversationID *uint `gorm:"index"`
|
||||||
|
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||||
|
ParentID *uint
|
||||||
|
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||||
|
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||||
|
SelectedReplyID *uint
|
||||||
|
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||||
|
|
||||||
|
Role api.MessageRole
|
||||||
|
Content string
|
||||||
|
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||||
|
ToolResults ToolResults // a json array of tool results
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MessageMeta) Scan(value interface{}) error {
|
||||||
|
return json.Unmarshal(value.([]byte), m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MessageMeta) Value() (driver.Value, error) {
|
||||||
|
return json.Marshal(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCalls []api.ToolCall
|
||||||
|
|
||||||
|
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tc = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tc)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||||
|
if len(tc) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(tc)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolResults []api.ToolResult
|
||||||
|
|
||||||
|
func (tr *ToolResults) Scan(value any) (err error) {
|
||||||
|
s := value.(string)
|
||||||
|
if value == nil || s == "" {
|
||||||
|
*tr = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = json.Unmarshal([]byte(s), tr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr ToolResults) Value() (driver.Value, error) {
|
||||||
|
if len(tr) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal([]api.ToolResult(tr))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||||
|
}
|
||||||
|
return string(jsonBytes), nil
|
||||||
|
}
|
520
pkg/conversation/repo.go
Normal file
520
pkg/conversation/repo.go
Normal file
@ -0,0 +1,520 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
sqids "github.com/sqids/sqids-go"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Repo exposes low-level message and conversation management. See
|
||||||
|
// Service for high-level helpers
|
||||||
|
type Repo interface {
|
||||||
|
LoadConversationList() (ConversationList, error)
|
||||||
|
|
||||||
|
FindConversationByShortName(shortName string) (*Conversation, error)
|
||||||
|
ConversationShortNameCompletions(search string) []string
|
||||||
|
GetConversationByID(int uint) (*Conversation, error)
|
||||||
|
GetRootMessages(conversationID uint) ([]Message, error)
|
||||||
|
|
||||||
|
CreateConversation(title string) (*Conversation, error)
|
||||||
|
UpdateConversation(*Conversation) error
|
||||||
|
DeleteConversation(*Conversation) error
|
||||||
|
DeleteConversationById(id uint) error
|
||||||
|
|
||||||
|
GetMessageByID(messageID uint) (*Message, error)
|
||||||
|
|
||||||
|
SaveMessage(message Message) (*Message, error)
|
||||||
|
UpdateMessage(message *Message) error
|
||||||
|
DeleteMessage(message *Message, prune bool) error
|
||||||
|
CloneBranch(toClone Message) (*Message, uint, error)
|
||||||
|
Reply(to *Message, messages ...Message) ([]Message, error)
|
||||||
|
|
||||||
|
PathToRoot(message *Message) ([]Message, error)
|
||||||
|
PathToLeaf(message *Message) ([]Message, error)
|
||||||
|
|
||||||
|
// Retrieves and return the "selected thread" of the conversation.
|
||||||
|
// The "selected thread" of the conversation is a chain of messages
|
||||||
|
// starting from the Conversation's SelectedRoot Message, following each
|
||||||
|
// Message's SelectedReply until the tail Message is reached.
|
||||||
|
GetSelectedThread(*Conversation) ([]Message, error)
|
||||||
|
|
||||||
|
// Start a new conversation with the given messages
|
||||||
|
StartConversation(messages ...Message) (*Conversation, []Message, error)
|
||||||
|
|
||||||
|
CloneConversation(toClone Conversation) (*Conversation, uint, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type repo struct {
|
||||||
|
db *gorm.DB
|
||||||
|
sqids *sqids.Sqids
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRepo(db *gorm.DB) (Repo, error) {
|
||||||
|
models := []any{
|
||||||
|
&Conversation{},
|
||||||
|
&Message{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, x := range models {
|
||||||
|
err := db.AutoMigrate(x)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||||
|
return &repo{db, _sqids}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConversationListItem struct {
|
||||||
|
ID uint
|
||||||
|
ShortName string
|
||||||
|
Title string
|
||||||
|
LastMessageAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConversationList struct {
|
||||||
|
Total int
|
||||||
|
Items []ConversationListItem
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConversationList loads existing conversations, ordered by the date
|
||||||
|
// of their latest message, from most recent to oldest.
|
||||||
|
func (s *repo) LoadConversationList() (ConversationList, error) {
|
||||||
|
list := ConversationList{}
|
||||||
|
|
||||||
|
var convos []Conversation
|
||||||
|
err := s.db.Order("last_message_at DESC").Find(&convos).Error
|
||||||
|
if err != nil {
|
||||||
|
return list, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range convos {
|
||||||
|
list.Items = append(list.Items, ConversationListItem{
|
||||||
|
ID: c.ID,
|
||||||
|
ShortName: c.ShortName.String,
|
||||||
|
Title: c.Title,
|
||||||
|
LastMessageAt: c.LastMessageAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
list.Total = len(list.Items)
|
||||||
|
return list, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) FindConversationByShortName(shortName string) (*Conversation, error) {
|
||||||
|
if shortName == "" {
|
||||||
|
return nil, errors.New("shortName is empty")
|
||||||
|
}
|
||||||
|
var conversation Conversation
|
||||||
|
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||||
|
return &conversation, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) ConversationShortNameCompletions(shortName string) []string {
|
||||||
|
var conversations []Conversation
|
||||||
|
// ignore error for completions
|
||||||
|
s.db.Find(&conversations)
|
||||||
|
completions := make([]string, 0, len(conversations))
|
||||||
|
for _, conversation := range conversations {
|
||||||
|
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||||
|
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return completions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) GetConversationByID(id uint) (*Conversation, error) {
|
||||||
|
var conversation Conversation
|
||||||
|
err := s.db.Preload("SelectedRoot").Where("id = ?", id).Find(&conversation).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Cannot get conversation %d: %v", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootMessages, err := s.GetRootMessages(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not load conversation's root messages %d: %v", id, err)
|
||||||
|
}
|
||||||
|
conversation.RootMessages = rootMessages
|
||||||
|
return &conversation, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) CreateConversation(title string) (*Conversation, error) {
|
||||||
|
// Create the new conversation
|
||||||
|
c := &Conversation{Title: title}
|
||||||
|
err := s.db.Create(c).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Generate and save its "short name"
|
||||||
|
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||||
|
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||||
|
err = s.db.Updates(c).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) UpdateConversation(c *Conversation) error {
|
||||||
|
if c == nil || c.ID == 0 {
|
||||||
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
|
return s.db.Updates(c).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) DeleteConversation(c *Conversation) error {
|
||||||
|
if c == nil || c.ID == 0 {
|
||||||
|
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
|
return s.DeleteConversationById(c.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) DeleteConversationById(id uint) error {
|
||||||
|
if id == 0 {
|
||||||
|
return fmt.Errorf("Invalid conversation ID: %d", id)
|
||||||
|
}
|
||||||
|
err := s.db.Where("conversation_id = ?", id).Delete(&Message{}).Error
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.db.Where("id = ?", id).Delete(&Conversation{}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) SaveMessage(m Message) (*Message, error) {
|
||||||
|
if m.Conversation == nil {
|
||||||
|
return nil, fmt.Errorf("Can't save a message without a conversation (this is a bug)")
|
||||||
|
}
|
||||||
|
newMessage := m
|
||||||
|
newMessage.ID = 0
|
||||||
|
newMessage.CreatedAt = time.Now()
|
||||||
|
return &newMessage, s.db.Create(&newMessage).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) UpdateMessage(m *Message) error {
|
||||||
|
if m == nil || m.ID == 0 {
|
||||||
|
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||||
|
}
|
||||||
|
return s.db.Updates(m).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) DeleteMessage(message *Message, prune bool) error {
|
||||||
|
return s.db.Delete(&message).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) GetMessageByID(messageID uint) (*Message, error) {
|
||||||
|
var message Message
|
||||||
|
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||||
|
return &message, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reply to a message with a series of messages (each followed by the next)
|
||||||
|
func (s *repo) Reply(to *Message, messages ...Message) ([]Message, error) {
|
||||||
|
var savedMessages []Message
|
||||||
|
|
||||||
|
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||||
|
currentParent := to
|
||||||
|
for i := range messages {
|
||||||
|
parent := currentParent
|
||||||
|
message := messages[i]
|
||||||
|
message.Parent = parent
|
||||||
|
message.Conversation = parent.Conversation
|
||||||
|
message.ID = 0
|
||||||
|
message.CreatedAt = time.Time{}
|
||||||
|
|
||||||
|
if err := tx.Create(&message).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// update parent selected reply
|
||||||
|
parent.Replies = append(parent.Replies, message)
|
||||||
|
parent.SelectedReply = &message
|
||||||
|
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
savedMessages = append(savedMessages, message)
|
||||||
|
currentParent = &message
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return savedMessages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
to.Conversation.LastMessageAt = savedMessages[len(savedMessages)-1].CreatedAt
|
||||||
|
err = s.UpdateConversation(to.Conversation)
|
||||||
|
return savedMessages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||||
|
// a new message object. The new message will be attached to the same parent as
|
||||||
|
// the messageToClone
|
||||||
|
func (s *repo) CloneBranch(messageToClone Message) (*Message, uint, error) {
|
||||||
|
newMessage := messageToClone
|
||||||
|
newMessage.ID = 0
|
||||||
|
newMessage.Replies = nil
|
||||||
|
newMessage.SelectedReplyID = nil
|
||||||
|
newMessage.SelectedReply = nil
|
||||||
|
|
||||||
|
originalReplies := messageToClone.Replies
|
||||||
|
|
||||||
|
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var replyCount uint = 0
|
||||||
|
for _, reply := range originalReplies {
|
||||||
|
replyCount++
|
||||||
|
|
||||||
|
newReply := reply
|
||||||
|
newReply.ConversationID = messageToClone.ConversationID
|
||||||
|
newReply.ParentID = &newMessage.ID
|
||||||
|
newReply.Parent = &newMessage
|
||||||
|
|
||||||
|
res, c, err := s.CloneBranch(newReply)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
newMessage.Replies = append(newMessage.Replies, *res)
|
||||||
|
replyCount += c
|
||||||
|
|
||||||
|
if reply.ID == *messageToClone.SelectedReplyID {
|
||||||
|
newMessage.SelectedReplyID = &res.ID
|
||||||
|
if err := s.UpdateMessage(&newMessage); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &newMessage, replyCount, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchMessages(db *gorm.DB) ([]Message, error) {
|
||||||
|
var messages []Message
|
||||||
|
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
messageMap := make(map[uint]Message)
|
||||||
|
for i, message := range messages {
|
||||||
|
messageMap[messages[i].ID] = message
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a map to store replies by their parent ID
|
||||||
|
repliesMap := make(map[uint][]Message)
|
||||||
|
for i, message := range messages {
|
||||||
|
if messages[i].ParentID != nil {
|
||||||
|
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assign replies, parent, and selected reply to each message
|
||||||
|
for i := range messages {
|
||||||
|
if replies, exists := repliesMap[messages[i].ID]; exists {
|
||||||
|
messages[i].Replies = make([]Message, len(replies))
|
||||||
|
for j, m := range replies {
|
||||||
|
messages[i].Replies[j] = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if messages[i].ParentID != nil {
|
||||||
|
if parent, exists := messageMap[*messages[i].ParentID]; exists {
|
||||||
|
messages[i].Parent = &parent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if messages[i].SelectedReplyID != nil {
|
||||||
|
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
|
||||||
|
messages[i].SelectedReply = &selectedReply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r repo) GetRootMessages(conversationID uint) ([]Message, error) {
|
||||||
|
var rootMessages []Message
|
||||||
|
err := r.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not retrieve root messages for conversation %d: %v", conversationID, err)
|
||||||
|
}
|
||||||
|
return rootMessages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) buildPath(message *Message, getNext func(*Message) *uint) ([]Message, error) {
|
||||||
|
var messages []Message
|
||||||
|
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a map to store messages by their ID
|
||||||
|
messageMap := make(map[uint]*Message, len(messages))
|
||||||
|
for i := range messages {
|
||||||
|
messageMap[messages[i].ID] = &messages[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct Replies
|
||||||
|
repliesMap := make(map[uint][]*Message, len(messages))
|
||||||
|
for _, m := range messageMap {
|
||||||
|
if m.ParentID == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p, ok := messageMap[*m.ParentID]; ok {
|
||||||
|
repliesMap[p.ID] = append(repliesMap[p.ID], m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add replies to messages
|
||||||
|
for _, m := range messageMap {
|
||||||
|
if replies, ok := repliesMap[m.ID]; ok {
|
||||||
|
m.Replies = make([]Message, len(replies))
|
||||||
|
for idx, reply := range replies {
|
||||||
|
m.Replies[idx] = *reply
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the path
|
||||||
|
var path []Message
|
||||||
|
nextID := &message.ID
|
||||||
|
|
||||||
|
for {
|
||||||
|
current, exists := messageMap[*nextID]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
|
||||||
|
}
|
||||||
|
|
||||||
|
path = append(path, *current)
|
||||||
|
|
||||||
|
nextID = getNext(current)
|
||||||
|
if nextID == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathToRoot traverses the provided message's Parent until reaching the tree
|
||||||
|
// root and returns a slice of all messages traversed in chronological order
|
||||||
|
// (starting with the root and ending with the message provided)
|
||||||
|
func (s *repo) PathToRoot(message *Message) ([]Message, error) {
|
||||||
|
if message == nil || message.ID <= 0 {
|
||||||
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
path, err := s.buildPath(message, func(m *Message) *uint {
|
||||||
|
return m.ParentID
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
slices.Reverse(path)
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
||||||
|
// tree leaf and returns a slice of all messages traversed in chronological
|
||||||
|
// order (starting with the message provided and ending with the leaf)
|
||||||
|
func (s *repo) PathToLeaf(message *Message) ([]Message, error) {
|
||||||
|
if message == nil || message.ID <= 0 {
|
||||||
|
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.buildPath(message, func(m *Message) *uint {
|
||||||
|
return m.SelectedReplyID
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) StartConversation(messages ...Message) (*Conversation, []Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new conversation
|
||||||
|
conversation, err := s.CreateConversation("")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages[0].Conversation = conversation
|
||||||
|
|
||||||
|
// Create first message
|
||||||
|
firstMessage, err := s.SaveMessage(messages[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages[0] = *firstMessage
|
||||||
|
|
||||||
|
// Update conversation's selected root message
|
||||||
|
conversation.RootMessages = []Message{messages[0]}
|
||||||
|
conversation.SelectedRoot = &messages[0]
|
||||||
|
conversation.LastMessageAt = messages[0].CreatedAt
|
||||||
|
|
||||||
|
// Add additional replies to conversation
|
||||||
|
if len(messages) > 1 {
|
||||||
|
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
messages = append([]Message{messages[0]}, newMessages...)
|
||||||
|
conversation.LastMessageAt = messages[len(messages)-1].CreatedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.UpdateConversation(conversation)
|
||||||
|
return conversation, messages, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneConversation clones the given conversation and all of its meesages
|
||||||
|
func (s *repo) CloneConversation(toClone Conversation) (*Conversation, uint, error) {
|
||||||
|
rootMessages, err := s.GetRootMessages(toClone.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clone, err := s.CreateConversation(toClone.Title + " - Clone")
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("Could not create clone: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors []error
|
||||||
|
var messageCnt uint = 0
|
||||||
|
for _, root := range rootMessages {
|
||||||
|
messageCnt++
|
||||||
|
newRoot := root
|
||||||
|
newRoot.ConversationID = &clone.ID
|
||||||
|
|
||||||
|
cloned, count, err := s.CloneBranch(newRoot)
|
||||||
|
if err != nil {
|
||||||
|
errors = append(errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messageCnt += count
|
||||||
|
|
||||||
|
if root.ID == *toClone.SelectedRootID {
|
||||||
|
clone.SelectedRootID = &cloned.ID
|
||||||
|
if err := s.UpdateConversation(clone); err != nil {
|
||||||
|
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errors) > 0 {
|
||||||
|
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||||
|
}
|
||||||
|
|
||||||
|
return clone, messageCnt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *repo) GetSelectedThread(c *Conversation) ([]Message, error) {
|
||||||
|
if c.SelectedRoot == nil {
|
||||||
|
return nil, fmt.Errorf("No SelectedRoot on conversation - this is a bug")
|
||||||
|
}
|
||||||
|
return s.PathToLeaf(c.SelectedRoot)
|
||||||
|
}
|
55
pkg/conversation/tools.go
Normal file
55
pkg/conversation/tools.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package conversation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplySystemPrompt updates the contents of an existing system Message if it
|
||||||
|
// exists, or returns a new slice with the system Message prepended.
|
||||||
|
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||||
|
if len(m) > 0 && m[0].Role == api.MessageRoleSystem {
|
||||||
|
if force {
|
||||||
|
m[0].Content = system
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
} else {
|
||||||
|
return append([]Message{{
|
||||||
|
Role: api.MessageRoleSystem,
|
||||||
|
Content: system,
|
||||||
|
}}, m...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessageToAPI(m Message) api.Message {
|
||||||
|
return api.Message{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
ToolCalls: m.ToolCalls,
|
||||||
|
ToolResults: m.ToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessagesToAPI(messages []Message) []api.Message {
|
||||||
|
ret := make([]api.Message, 0, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
ret = append(ret, MessageToAPI(m))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessageFromAPI(m api.Message) Message {
|
||||||
|
return Message{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
ToolCalls: m.ToolCalls,
|
||||||
|
ToolResults: m.ToolResults,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func MessagesFromAPI(messages []api.Message) []Message {
|
||||||
|
ret := make([]Message, 0, len(messages))
|
||||||
|
for _, m := range messages {
|
||||||
|
ret = append(ret, MessageFromAPI(m))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
@ -5,34 +5,39 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"github.com/go-yaml/yaml"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Defaults *struct {
|
Defaults *struct {
|
||||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
Model *string `yaml:"model" default:"gpt-4"`
|
||||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||||
Temperature *float32 `yaml:"temperature" default:"0.7"`
|
Temperature *float32 `yaml:"temperature" default:"0.2"`
|
||||||
Model *string `yaml:"model" default:"gpt-4"`
|
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
||||||
|
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
||||||
|
Agent string `yaml:"agent"`
|
||||||
} `yaml:"defaults"`
|
} `yaml:"defaults"`
|
||||||
Conversations *struct {
|
Conversations *struct {
|
||||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||||
} `yaml:"conversations"`
|
} `yaml:"conversations"`
|
||||||
Tools *struct {
|
|
||||||
EnabledTools *[]string `yaml:"enabledTools"`
|
|
||||||
} `yaml:"tools"`
|
|
||||||
OpenAI *struct {
|
|
||||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
|
||||||
Models *[]string `yaml:"models"`
|
|
||||||
} `yaml:"openai"`
|
|
||||||
Anthropic *struct {
|
|
||||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
|
||||||
Models *[]string `yaml:"models"`
|
|
||||||
} `yaml:"anthropic"`
|
|
||||||
Chroma *struct {
|
Chroma *struct {
|
||||||
Style *string `yaml:"style" default:"onedark"`
|
Style *string `yaml:"style" default:"onedark"`
|
||||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||||
} `yaml:"chroma"`
|
} `yaml:"chroma"`
|
||||||
|
Agents []*struct {
|
||||||
|
Name string `yaml:"name"`
|
||||||
|
SystemPrompt string `yaml:"systemPrompt"`
|
||||||
|
Tools []string `yaml:"tools"`
|
||||||
|
} `yaml:"agents"`
|
||||||
|
Providers []*struct {
|
||||||
|
Name string `yaml:"name,omitempty"`
|
||||||
|
Display string `yaml:"display,omitempty"`
|
||||||
|
Kind string `yaml:"kind"`
|
||||||
|
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||||
|
APIKey string `yaml:"apiKey,omitempty"`
|
||||||
|
Models []string `yaml:"models"`
|
||||||
|
Headers map[string]string `yaml:"headers"`
|
||||||
|
} `yaml:"providers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConfig(configFile string) (*Config, error) {
|
func NewConfig(configFile string) (*Config, error) {
|
||||||
@ -60,8 +65,9 @@ func NewConfig(configFile string) (*Config, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||||
}
|
}
|
||||||
bytes, _ := yaml.Marshal(c)
|
encoder := yaml.NewEncoder(file)
|
||||||
_, err = file.Write(bytes)
|
encoder.SetIndent(2)
|
||||||
|
err = encoder.Encode(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1,86 +1,231 @@
|
|||||||
package lmcli
|
package lmcli
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/anthropic"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
"git.mlow.ca/mlow/lmcli/pkg/provider/google"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider/ollama"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider/openai"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Context struct {
|
type Agent struct {
|
||||||
Config *Config
|
Name string
|
||||||
Store ConversationStore
|
SystemPrompt string
|
||||||
|
Toolbox []api.ToolSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
type Context struct {
|
||||||
|
// high level app configuration, may be mutated at runtime
|
||||||
|
Config Config
|
||||||
|
Conversations conversation.Repo
|
||||||
Chroma *tty.ChromaHighlighter
|
Chroma *tty.ChromaHighlighter
|
||||||
EnabledTools []model.Tool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContext() (*Context, error) {
|
func NewContext() (*Context, error) {
|
||||||
configFile := filepath.Join(configDir(), "config.yaml")
|
configFile := filepath.Join(configDir(), "config.yaml")
|
||||||
config, err := NewConfig(configFile)
|
config, err := NewConfig(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
Fatal("%v\n", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
store, err := getConversationService()
|
||||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
return nil, err
|
||||||
}
|
|
||||||
store, err := NewSQLStore(db)
|
|
||||||
if err != nil {
|
|
||||||
Fatal("%v\n", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||||
|
return &Context{*config, store, chroma}, nil
|
||||||
|
}
|
||||||
|
|
||||||
var enabledTools []model.Tool
|
func createOrOpenAppend(path string) (*os.File, error) {
|
||||||
for _, toolName := range *config.Tools.EnabledTools {
|
var file *os.File
|
||||||
tool, ok := tools.AvailableTools[toolName]
|
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
|
||||||
if ok {
|
file, err = os.Create(path)
|
||||||
enabledTools = append(enabledTools, tool)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
file, err = os.OpenFile(path, os.O_APPEND, fs.ModeAppend)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return file, nil
|
||||||
|
}
|
||||||
|
|
||||||
return &Context{config, store, chroma, enabledTools}, nil
|
func getConversationService() (conversation.Repo, error) {
|
||||||
|
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||||
|
gormLogFile, err := createOrOpenAppend(filepath.Join(dataDir(), "database.log"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not open database log file: %v", err)
|
||||||
|
}
|
||||||
|
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||||
|
Logger: logger.New(log.New(gormLogFile, "\n", log.LstdFlags), logger.Config{
|
||||||
|
SlowThreshold: 200 * time.Millisecond,
|
||||||
|
LogLevel: logger.Info,
|
||||||
|
IgnoreRecordNotFoundError: false,
|
||||||
|
Colorful: true,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||||
|
}
|
||||||
|
repo, err := conversation.NewRepo(db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return repo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetModels() (models []string) {
|
func (c *Context) GetModels() (models []string) {
|
||||||
for _, m := range *c.Config.Anthropic.Models {
|
modelCounts := make(map[string]int)
|
||||||
models = append(models, m)
|
for _, p := range c.Config.Providers {
|
||||||
|
name := p.Kind
|
||||||
|
if p.Name != "" {
|
||||||
|
name = p.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range p.Models {
|
||||||
|
modelCounts[m]++
|
||||||
|
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for _, m := range *c.Config.OpenAI.Models {
|
|
||||||
models = append(models, m)
|
for m, c := range modelCounts {
|
||||||
|
if c == 1 {
|
||||||
|
models = append(models, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetAgents() (agents []string) {
|
||||||
|
for _, p := range c.Config.Agents {
|
||||||
|
agents = append(agents, p.Name)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
func (c *Context) GetAgent(name string) *Agent {
|
||||||
for _, m := range *c.Config.Anthropic.Models {
|
if name == "" || name == "none" {
|
||||||
if m == model {
|
return nil
|
||||||
anthropic := &anthropic.AnthropicClient{
|
}
|
||||||
APIKey: *c.Config.Anthropic.APIKey,
|
|
||||||
|
for _, a := range c.Config.Agents {
|
||||||
|
if name != a.Name {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var enabledTools []api.ToolSpec
|
||||||
|
for _, toolName := range a.Tools {
|
||||||
|
tool, ok := agents.AvailableTools[toolName]
|
||||||
|
if ok {
|
||||||
|
enabledTools = append(enabledTools, tool)
|
||||||
}
|
}
|
||||||
return anthropic, nil
|
}
|
||||||
|
|
||||||
|
return &Agent{
|
||||||
|
Name: a.Name,
|
||||||
|
SystemPrompt: a.SystemPrompt,
|
||||||
|
Toolbox: enabledTools,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for _, m := range *c.Config.OpenAI.Models {
|
return nil
|
||||||
if m == model {
|
}
|
||||||
openai := &openai.OpenAIClient{
|
|
||||||
APIKey: *c.Config.OpenAI.APIKey,
|
func (c *Context) DefaultSystemPrompt() string {
|
||||||
|
if c.Config.Defaults.SystemPromptFile != "" {
|
||||||
|
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
|
||||||
|
if err != nil {
|
||||||
|
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
return c.Config.Defaults.SystemPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) GetModelProvider(model string, provider string) (string, string, provider.ChatCompletionProvider, error) {
|
||||||
|
parts := strings.Split(model, "@")
|
||||||
|
|
||||||
|
if provider == "" && len(parts) > 1 {
|
||||||
|
model = parts[0]
|
||||||
|
provider = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range c.Config.Providers {
|
||||||
|
name := p.Kind
|
||||||
|
if p.Name != "" {
|
||||||
|
name = p.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider != "" && name != provider {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range p.Models {
|
||||||
|
if m == model {
|
||||||
|
switch p.Kind {
|
||||||
|
case "anthropic":
|
||||||
|
url := "https://api.anthropic.com"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
return model, name, &anthropic.AnthropicClient{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
}, nil
|
||||||
|
case "google":
|
||||||
|
url := "https://generativelanguage.googleapis.com"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
return model, name, &google.Client{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
}, nil
|
||||||
|
case "ollama":
|
||||||
|
url := "http://localhost:11434/api"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
return model, name, &ollama.OllamaClient{
|
||||||
|
BaseURL: url,
|
||||||
|
}, nil
|
||||||
|
case "openai":
|
||||||
|
url := "https://api.openai.com"
|
||||||
|
if p.BaseURL != "" {
|
||||||
|
url = p.BaseURL
|
||||||
|
}
|
||||||
|
return model, name, &openai.OpenAIClient{
|
||||||
|
BaseURL: url,
|
||||||
|
APIKey: p.APIKey,
|
||||||
|
Headers: p.Headers,
|
||||||
|
}, nil
|
||||||
|
default:
|
||||||
|
return "", "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return openai, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown model: %s", model)
|
return "", "", nil, fmt.Errorf("unknown model: %s", model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func configDir() string {
|
func configDir() string {
|
||||||
|
@ -1,58 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageRole string
|
|
||||||
|
|
||||||
const (
|
|
||||||
MessageRoleSystem MessageRole = "system"
|
|
||||||
MessageRoleUser MessageRole = "user"
|
|
||||||
MessageRoleAssistant MessageRole = "assistant"
|
|
||||||
MessageRoleToolCall MessageRole = "tool_call"
|
|
||||||
MessageRoleToolResult MessageRole = "tool_result"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
|
||||||
Content string
|
|
||||||
Role MessageRole
|
|
||||||
CreatedAt time.Time
|
|
||||||
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
|
||||||
ToolResults ToolResults // a json array of tool results
|
|
||||||
}
|
|
||||||
|
|
||||||
type Conversation struct {
|
|
||||||
ID uint `gorm:"primaryKey"`
|
|
||||||
ShortName sql.NullString
|
|
||||||
Title string
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestParameters struct {
|
|
||||||
Model string
|
|
||||||
MaxTokens int
|
|
||||||
Temperature float32
|
|
||||||
TopP float32
|
|
||||||
|
|
||||||
SystemPrompt string
|
|
||||||
ToolBag []Tool
|
|
||||||
}
|
|
||||||
|
|
||||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
|
||||||
func (m *MessageRole) FriendlyRole() string {
|
|
||||||
var friendlyRole string
|
|
||||||
switch *m {
|
|
||||||
case MessageRoleUser:
|
|
||||||
friendlyRole = "You"
|
|
||||||
case MessageRoleSystem:
|
|
||||||
friendlyRole = "System"
|
|
||||||
case MessageRoleAssistant:
|
|
||||||
friendlyRole = "Assistant"
|
|
||||||
default:
|
|
||||||
friendlyRole = string(*m)
|
|
||||||
}
|
|
||||||
return friendlyRole
|
|
||||||
}
|
|
@ -1,98 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Tool struct {
|
|
||||||
Name string
|
|
||||||
Description string
|
|
||||||
Parameters []ToolParameter
|
|
||||||
Impl func(*Tool, map[string]interface{}) (string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolParameter struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Type string `json:"type"` // "string", "integer", "boolean"
|
|
||||||
Required bool `json:"required"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCall struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Parameters map[string]interface{} `json:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolCalls []ToolCall
|
|
||||||
|
|
||||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
|
||||||
s := value.(string)
|
|
||||||
if value == nil || s == "" {
|
|
||||||
*tc = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = json.Unmarshal([]byte(s), tc)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
|
||||||
if len(tc) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(tc)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResult struct {
|
|
||||||
ToolCallID string `json:"toolCallID"`
|
|
||||||
ToolName string `json:"toolName,omitempty"`
|
|
||||||
Result string `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ToolResults []ToolResult
|
|
||||||
|
|
||||||
func (tr *ToolResults) Scan(value any) (err error) {
|
|
||||||
s := value.(string)
|
|
||||||
if value == nil || s == "" {
|
|
||||||
*tr = nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = json.Unmarshal([]byte(s), tr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tr ToolResults) Value() (driver.Value, error) {
|
|
||||||
if len(tr) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallResult struct {
|
|
||||||
Message string `json:"message"`
|
|
||||||
Result any `json:"result,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r CallResult) ToJson() (string, error) {
|
|
||||||
if r.Message == "" {
|
|
||||||
// When message not supplied, assume success
|
|
||||||
r.Message = "success"
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(r)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
|
||||||
}
|
|
||||||
return string(jsonBytes), nil
|
|
||||||
}
|
|
@ -1,348 +0,0 @@
|
|||||||
package anthropic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"encoding/xml"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AnthropicClient struct {
|
|
||||||
APIKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Request struct {
|
|
||||||
Model string `json:"model"`
|
|
||||||
Messages []Message `json:"messages"`
|
|
||||||
System string `json:"system,omitempty"`
|
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
|
||||||
Stream bool `json:"stream,omitempty"`
|
|
||||||
Temperature float32 `json:"temperature,omitempty"`
|
|
||||||
//TopP float32 `json:"top_p,omitempty"`
|
|
||||||
//TopK float32 `json:"top_k,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OriginalContent struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Text string `json:"text"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Response struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Type string `json:"type"`
|
|
||||||
Role string `json:"role"`
|
|
||||||
Content []OriginalContent `json:"content"`
|
|
||||||
}
|
|
||||||
|
|
||||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
|
||||||
|
|
||||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
|
||||||
requestBody := Request{
|
|
||||||
Model: params.Model,
|
|
||||||
Messages: make([]Message, len(messages)),
|
|
||||||
System: params.SystemPrompt,
|
|
||||||
MaxTokens: params.MaxTokens,
|
|
||||||
Temperature: params.Temperature,
|
|
||||||
Stream: false,
|
|
||||||
|
|
||||||
StopSequences: []string{
|
|
||||||
FUNCTION_STOP_SEQUENCE,
|
|
||||||
"\n\nHuman:",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
startIdx := 0
|
|
||||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
|
||||||
requestBody.System = messages[0].Content
|
|
||||||
requestBody.Messages = requestBody.Messages[1:]
|
|
||||||
startIdx = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(params.ToolBag) > 0 {
|
|
||||||
if len(requestBody.System) > 0 {
|
|
||||||
// add a divider between existing system prompt and tools
|
|
||||||
requestBody.System += "\n\n---\n\n"
|
|
||||||
}
|
|
||||||
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, msg := range messages[startIdx:] {
|
|
||||||
message := &requestBody.Messages[i]
|
|
||||||
|
|
||||||
switch msg.Role {
|
|
||||||
case model.MessageRoleToolCall:
|
|
||||||
message.Role = "assistant"
|
|
||||||
if msg.Content != "" {
|
|
||||||
message.Content = msg.Content
|
|
||||||
}
|
|
||||||
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
|
||||||
xmlString, err := xmlFuncCalls.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
|
||||||
}
|
|
||||||
if len(message.Content) > 0 {
|
|
||||||
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
|
||||||
} else {
|
|
||||||
message.Content = xmlString
|
|
||||||
}
|
|
||||||
case model.MessageRoleToolResult:
|
|
||||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
|
||||||
xmlString, err := xmlFuncResults.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
|
||||||
}
|
|
||||||
message.Role = "user"
|
|
||||||
message.Content = xmlString
|
|
||||||
default:
|
|
||||||
message.Role = string(msg.Role)
|
|
||||||
message.Content = msg.Content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return requestBody
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
|
||||||
url := "https://api.anthropic.com/v1/messages"
|
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header.Set("x-api-key", c.APIKey)
|
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
|
||||||
req.Header.Set("content-type", "application/json")
|
|
||||||
|
|
||||||
client := &http.Client{}
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletion(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback provider.ReplyCallback,
|
|
||||||
) (string, error) {
|
|
||||||
request := buildRequest(params, messages)
|
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
var response Response
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sb := strings.Builder{}
|
|
||||||
for _, content := range response.Content {
|
|
||||||
var reply model.Message
|
|
||||||
switch content.Type {
|
|
||||||
case "text":
|
|
||||||
reply = model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.Text,
|
|
||||||
}
|
|
||||||
sb.WriteString(reply.Content)
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
|
||||||
}
|
|
||||||
if callback != nil {
|
|
||||||
callback(reply)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback provider.ReplyCallback,
|
|
||||||
output chan<- string,
|
|
||||||
) (string, error) {
|
|
||||||
request := buildRequest(params, messages)
|
|
||||||
request.Stream = true
|
|
||||||
|
|
||||||
resp, err := sendRequest(ctx, c, request)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
|
||||||
sb := strings.Builder{}
|
|
||||||
|
|
||||||
isToolCall := false
|
|
||||||
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
line = strings.TrimSpace(line)
|
|
||||||
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if line[0] == '{' {
|
|
||||||
var event map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(line), &event)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
|
||||||
}
|
|
||||||
eventType, ok := event["type"].(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("invalid event: %s", line)
|
|
||||||
}
|
|
||||||
switch eventType {
|
|
||||||
case "error":
|
|
||||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
|
||||||
default:
|
|
||||||
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(line, "data:") {
|
|
||||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
|
||||||
var event map[string]interface{}
|
|
||||||
err := json.Unmarshal([]byte(data), &event)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
eventType, ok := event["type"].(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("invalid event type")
|
|
||||||
}
|
|
||||||
|
|
||||||
switch eventType {
|
|
||||||
case "message_start":
|
|
||||||
// noop
|
|
||||||
case "ping":
|
|
||||||
// write an empty string to signal start of text
|
|
||||||
output <- ""
|
|
||||||
case "content_block_start":
|
|
||||||
// ignore?
|
|
||||||
case "content_block_delta":
|
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("invalid content block delta")
|
|
||||||
}
|
|
||||||
text, ok := delta["text"].(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("invalid text delta")
|
|
||||||
}
|
|
||||||
sb.WriteString(text)
|
|
||||||
output <- text
|
|
||||||
case "content_block_stop":
|
|
||||||
// ignore?
|
|
||||||
case "message_delta":
|
|
||||||
delta, ok := event["delta"].(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("invalid message delta")
|
|
||||||
}
|
|
||||||
stopReason, ok := delta["stop_reason"].(string)
|
|
||||||
if ok && stopReason == "stop_sequence" {
|
|
||||||
stopSequence, ok := delta["stop_sequence"].(string)
|
|
||||||
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
|
||||||
content := sb.String()
|
|
||||||
|
|
||||||
start := strings.Index(content, "<function_calls>")
|
|
||||||
if start == -1 {
|
|
||||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
|
||||||
}
|
|
||||||
|
|
||||||
isToolCall = true
|
|
||||||
|
|
||||||
funcCallXml := content[start:]
|
|
||||||
funcCallXml += FUNCTION_STOP_SEQUENCE
|
|
||||||
|
|
||||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
|
||||||
output <- FUNCTION_STOP_SEQUENCE
|
|
||||||
|
|
||||||
// Extract function calls
|
|
||||||
var functionCalls XMLFunctionCalls
|
|
||||||
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute function calls
|
|
||||||
toolCall := model.Message{
|
|
||||||
Role: model.MessageRoleToolCall,
|
|
||||||
// xml stripped from content
|
|
||||||
Content: content[:start],
|
|
||||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolReply := model.Message{
|
|
||||||
Role: model.MessageRoleToolResult,
|
|
||||||
ToolResults: toolResults,
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(toolCall)
|
|
||||||
callback(toolReply)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
|
||||||
// added to the original messages
|
|
||||||
messages = append(append(messages, toolCall), toolReply)
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "message_stop":
|
|
||||||
// return the completed message
|
|
||||||
if callback != nil {
|
|
||||||
if !isToolCall {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: sb.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sb.String(), nil
|
|
||||||
case "error":
|
|
||||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
|
||||||
default:
|
|
||||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to read response body: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return "", fmt.Errorf("unexpected end of stream")
|
|
||||||
}
|
|
@ -1,230 +0,0 @@
|
|||||||
package anthropic
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
|
||||||
|
|
||||||
You may call them like this:
|
|
||||||
|
|
||||||
<function_calls>
|
|
||||||
<invoke>
|
|
||||||
<tool_name>$TOOL_NAME</tool_name>
|
|
||||||
<parameters>
|
|
||||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
|
||||||
...
|
|
||||||
</parameters>
|
|
||||||
</invoke>
|
|
||||||
</function_calls>
|
|
||||||
|
|
||||||
Here are the tools available:`
|
|
||||||
|
|
||||||
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
|
|
||||||
|
|
||||||
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
|
|
||||||
|
|
||||||
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
|
|
||||||
|
|
||||||
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
|
||||||
|
|
||||||
type XMLTools struct {
|
|
||||||
XMLName struct{} `xml:"tools"`
|
|
||||||
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLToolDescription struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Description string `xml:"description"`
|
|
||||||
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLToolParameter struct {
|
|
||||||
Name string `xml:"name"`
|
|
||||||
Type string `xml:"type"`
|
|
||||||
Description string `xml:"description"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionCalls struct {
|
|
||||||
XMLName struct{} `xml:"function_calls"`
|
|
||||||
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionInvoke struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionInvokeParameters struct {
|
|
||||||
String string `xml:",innerxml"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionResults struct {
|
|
||||||
XMLName struct{} `xml:"function_results"`
|
|
||||||
Result []XMLFunctionResult `xml:"result"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type XMLFunctionResult struct {
|
|
||||||
ToolName string `xml:"tool_name"`
|
|
||||||
Stdout string `xml:"stdout"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
|
||||||
// parameters name to value
|
|
||||||
func parseFunctionParametersXML(params string) map[string]interface{} {
|
|
||||||
lines := strings.Split(params, "\n")
|
|
||||||
ret := make(map[string]interface{}, len(lines))
|
|
||||||
for _, line := range lines {
|
|
||||||
i := strings.Index(line, ">")
|
|
||||||
if i == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
j := strings.Index(line, "</")
|
|
||||||
if j == -1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// chop from after opening < to first > to get parameter name,
|
|
||||||
// then chop after > to first </ to get parameter value
|
|
||||||
ret[line[1:i]] = line[i+1 : j]
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
|
||||||
converted := make([]XMLToolDescription, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
converted[i].ToolName = tool.Name
|
|
||||||
converted[i].Description = tool.Description
|
|
||||||
|
|
||||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
|
||||||
for j, param := range tool.Parameters {
|
|
||||||
params[j].Name = param.Name
|
|
||||||
params[j].Description = param.Description
|
|
||||||
params[j].Type = param.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
converted[i].Parameters = params
|
|
||||||
}
|
|
||||||
return XMLTools{
|
|
||||||
ToolDescriptions: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
|
||||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
|
||||||
for i, invoke := range functionCalls.Invoke {
|
|
||||||
toolCalls[i].Name = invoke.ToolName
|
|
||||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
|
||||||
}
|
|
||||||
return toolCalls
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
|
||||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
|
||||||
for i, toolCall := range toolCalls {
|
|
||||||
var params XMLFunctionInvokeParameters
|
|
||||||
var paramXML string
|
|
||||||
for key, value := range toolCall.Parameters {
|
|
||||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
|
||||||
}
|
|
||||||
params.String = paramXML
|
|
||||||
converted[i] = XMLFunctionInvoke{
|
|
||||||
ToolName: toolCall.Name,
|
|
||||||
Parameters: params,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return XMLFunctionCalls{
|
|
||||||
Invoke: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
|
||||||
converted := make([]XMLFunctionResult, len(toolResults))
|
|
||||||
for i, result := range toolResults {
|
|
||||||
converted[i].ToolName = result.ToolName
|
|
||||||
converted[i].Stdout = result.Result
|
|
||||||
}
|
|
||||||
return XMLFunctionResults{
|
|
||||||
Result: converted,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
|
||||||
xmlTools := convertToolsToXMLTools(tools)
|
|
||||||
xmlToolsString, err := xmlTools.XMLString()
|
|
||||||
if err != nil {
|
|
||||||
panic("Could not serialize []model.Tool to XMLTools")
|
|
||||||
}
|
|
||||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLTools) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("tools").Parse(`<tools>
|
|
||||||
{{range .ToolDescriptions}}<tool_description>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<description>
|
|
||||||
{{.Description}}
|
|
||||||
</description>
|
|
||||||
<parameters>
|
|
||||||
{{range .Parameters}}<parameter>
|
|
||||||
<name>{{.Name}}</name>
|
|
||||||
<type>{{.Type}}</type>
|
|
||||||
<description>{{.Description}}</description>
|
|
||||||
</parameter>
|
|
||||||
{{end}}</parameters>
|
|
||||||
</tool_description>
|
|
||||||
{{end}}</tools>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
|
||||||
{{range .Result}}<result>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<stdout>{{.Stdout}}</stdout>
|
|
||||||
</result>
|
|
||||||
{{end}}</function_results>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
|
||||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
|
||||||
{{range .Invoke}}<invoke>
|
|
||||||
<tool_name>{{.ToolName}}</tool_name>
|
|
||||||
<parameters>{{.Parameters.String}}</parameters>
|
|
||||||
</invoke>
|
|
||||||
{{end}}</function_calls>`)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := tmpl.Execute(&buf, x); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
@ -1,278 +0,0 @@
|
|||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
|
||||||
openai "github.com/sashabaranov/go-openai"
|
|
||||||
)
|
|
||||||
|
|
||||||
type OpenAIClient struct {
|
|
||||||
APIKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIToolParameters struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
|
|
||||||
Required []string `json:"required,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OpenAIToolParameter struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []string `json:"enum,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertTools(tools []model.Tool) []openai.Tool {
|
|
||||||
openaiTools := make([]openai.Tool, len(tools))
|
|
||||||
for i, tool := range tools {
|
|
||||||
openaiTools[i].Type = "function"
|
|
||||||
|
|
||||||
params := make(map[string]OpenAIToolParameter)
|
|
||||||
var required []string
|
|
||||||
|
|
||||||
for _, param := range tool.Parameters {
|
|
||||||
params[param.Name] = OpenAIToolParameter{
|
|
||||||
Type: param.Type,
|
|
||||||
Description: param.Description,
|
|
||||||
Enum: param.Enum,
|
|
||||||
}
|
|
||||||
if param.Required {
|
|
||||||
required = append(required, param.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
openaiTools[i].Function = openai.FunctionDefinition{
|
|
||||||
Name: tool.Name,
|
|
||||||
Description: tool.Description,
|
|
||||||
Parameters: OpenAIToolParameters{
|
|
||||||
Type: "object",
|
|
||||||
Properties: params,
|
|
||||||
Required: required,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return openaiTools
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
|
|
||||||
converted := make([]openai.ToolCall, len(toolCalls))
|
|
||||||
for i, call := range toolCalls {
|
|
||||||
converted[i].Type = "function"
|
|
||||||
converted[i].ID = call.ID
|
|
||||||
converted[i].Function.Name = call.Name
|
|
||||||
|
|
||||||
json, _ := json.Marshal(call.Parameters)
|
|
||||||
converted[i].Function.Arguments = string(json)
|
|
||||||
}
|
|
||||||
return converted
|
|
||||||
}
|
|
||||||
|
|
||||||
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
|
|
||||||
converted := make([]model.ToolCall, len(toolCalls))
|
|
||||||
for i, call := range toolCalls {
|
|
||||||
converted[i].ID = call.ID
|
|
||||||
converted[i].Name = call.Function.Name
|
|
||||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
|
||||||
}
|
|
||||||
return converted
|
|
||||||
}
|
|
||||||
|
|
||||||
func createChatCompletionRequest(
|
|
||||||
c *OpenAIClient,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
) openai.ChatCompletionRequest {
|
|
||||||
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
|
|
||||||
|
|
||||||
for _, m := range messages {
|
|
||||||
switch m.Role {
|
|
||||||
case "tool_call":
|
|
||||||
message := openai.ChatCompletionMessage{}
|
|
||||||
message.Role = "assistant"
|
|
||||||
message.Content = m.Content
|
|
||||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
|
||||||
requestMessages = append(requestMessages, message)
|
|
||||||
case "tool_result":
|
|
||||||
// expand tool_result messages' results into multiple openAI messages
|
|
||||||
for _, result := range m.ToolResults {
|
|
||||||
message := openai.ChatCompletionMessage{}
|
|
||||||
message.Role = "tool"
|
|
||||||
message.Content = result.Result
|
|
||||||
message.ToolCallID = result.ToolCallID
|
|
||||||
requestMessages = append(requestMessages, message)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
message := openai.ChatCompletionMessage{}
|
|
||||||
message.Role = string(m.Role)
|
|
||||||
message.Content = m.Content
|
|
||||||
requestMessages = append(requestMessages, message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
request := openai.ChatCompletionRequest{
|
|
||||||
Model: params.Model,
|
|
||||||
MaxTokens: params.MaxTokens,
|
|
||||||
Temperature: params.Temperature,
|
|
||||||
Messages: requestMessages,
|
|
||||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(params.ToolBag) > 0 {
|
|
||||||
request.Tools = convertTools(params.ToolBag)
|
|
||||||
request.ToolChoice = "auto"
|
|
||||||
}
|
|
||||||
|
|
||||||
return request
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleToolCalls(
|
|
||||||
params model.RequestParameters,
|
|
||||||
content string,
|
|
||||||
toolCalls []openai.ToolCall,
|
|
||||||
) ([]model.Message, error) {
|
|
||||||
toolCall := model.Message{
|
|
||||||
Role: model.MessageRoleToolCall,
|
|
||||||
Content: content,
|
|
||||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.Message{
|
|
||||||
Role: model.MessageRoleToolResult,
|
|
||||||
ToolResults: toolResults,
|
|
||||||
}
|
|
||||||
|
|
||||||
return []model.Message{toolCall, toolResult}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletion(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback provider.ReplyCallback,
|
|
||||||
) (string, error) {
|
|
||||||
client := openai.NewClient(c.APIKey)
|
|
||||||
req := createChatCompletionRequest(c, params, messages)
|
|
||||||
resp, err := client.CreateChatCompletion(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
choice := resp.Choices[0]
|
|
||||||
|
|
||||||
toolCalls := choice.Message.ToolCalls
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if callback != nil {
|
|
||||||
for _, result := range results {
|
|
||||||
callback(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletion with the tool call replies
|
|
||||||
messages = append(messages, results...)
|
|
||||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: choice.Message.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the user-facing message.
|
|
||||||
return choice.Message.Content, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback provider.ReplyCallback,
|
|
||||||
output chan<- string,
|
|
||||||
) (string, error) {
|
|
||||||
client := openai.NewClient(c.APIKey)
|
|
||||||
req := createChatCompletionRequest(c, params, messages)
|
|
||||||
|
|
||||||
stream, err := client.CreateChatCompletionStream(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
content := strings.Builder{}
|
|
||||||
toolCalls := []openai.ToolCall{}
|
|
||||||
|
|
||||||
// Iterate stream segments
|
|
||||||
for {
|
|
||||||
response, e := stream.Recv()
|
|
||||||
if errors.Is(e, io.EOF) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if e != nil {
|
|
||||||
err = e
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
delta := response.Choices[0].Delta
|
|
||||||
if len(delta.ToolCalls) > 0 {
|
|
||||||
// Construct streamed tool_call arguments
|
|
||||||
for _, tc := range delta.ToolCalls {
|
|
||||||
if tc.Index == nil {
|
|
||||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
|
||||||
}
|
|
||||||
if len(toolCalls) <= *tc.Index {
|
|
||||||
toolCalls = append(toolCalls, tc)
|
|
||||||
} else {
|
|
||||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
output <- delta.Content
|
|
||||||
content.WriteString(delta.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(toolCalls) > 0 {
|
|
||||||
results, err := handleToolCalls(params, content.String(), toolCalls)
|
|
||||||
if err != nil {
|
|
||||||
return content.String(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
if callback != nil {
|
|
||||||
for _, result := range results {
|
|
||||||
callback(result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
|
||||||
messages = append(messages, results...)
|
|
||||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
|
||||||
} else {
|
|
||||||
if callback != nil {
|
|
||||||
callback(model.Message{
|
|
||||||
Role: model.MessageRoleAssistant,
|
|
||||||
Content: content.String(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return content.String(), err
|
|
||||||
}
|
|
@ -1,31 +0,0 @@
|
|||||||
package provider
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ReplyCallback func(model.Message)
|
|
||||||
|
|
||||||
type ChatCompletionClient interface {
|
|
||||||
// CreateChatCompletion requests a response to the provided messages.
|
|
||||||
// Replies are appended to the given replies struct, and the
|
|
||||||
// complete user-facing response is returned as a string.
|
|
||||||
CreateChatCompletion(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback ReplyCallback,
|
|
||||||
) (string, error)
|
|
||||||
|
|
||||||
// Like CreateChageCompletion, except the response is streamed via
|
|
||||||
// the output channel as it's received.
|
|
||||||
CreateChatCompletionStream(
|
|
||||||
ctx context.Context,
|
|
||||||
params model.RequestParameters,
|
|
||||||
messages []model.Message,
|
|
||||||
callback ReplyCallback,
|
|
||||||
output chan<- string,
|
|
||||||
) (string, error)
|
|
||||||
}
|
|
@ -1,132 +0,0 @@
|
|||||||
package lmcli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
sqids "github.com/sqids/sqids-go"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ConversationStore interface {
|
|
||||||
Conversations() ([]model.Conversation, error)
|
|
||||||
|
|
||||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
|
||||||
ConversationShortNameCompletions(search string) []string
|
|
||||||
|
|
||||||
SaveConversation(conversation *model.Conversation) error
|
|
||||||
DeleteConversation(conversation *model.Conversation) error
|
|
||||||
|
|
||||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
|
||||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
|
||||||
|
|
||||||
SaveMessage(message *model.Message) error
|
|
||||||
DeleteMessage(message *model.Message) error
|
|
||||||
UpdateMessage(message *model.Message) error
|
|
||||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SQLStore struct {
|
|
||||||
db *gorm.DB
|
|
||||||
sqids *sqids.Sqids
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
|
||||||
models := []any{
|
|
||||||
&model.Conversation{},
|
|
||||||
&model.Message{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, x := range models {
|
|
||||||
err := db.AutoMigrate(x)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
|
||||||
return &SQLStore{db, _sqids}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
|
||||||
err := s.db.Save(&conversation).Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !conversation.ShortName.Valid {
|
|
||||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
|
||||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
|
||||||
err = s.db.Save(&conversation).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
|
||||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
|
||||||
return s.db.Delete(&conversation).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
|
||||||
return s.db.Create(message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
|
||||||
return s.db.Delete(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
|
||||||
return s.db.Updates(&message).Error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
|
||||||
var conversations []model.Conversation
|
|
||||||
err := s.db.Find(&conversations).Error
|
|
||||||
return conversations, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
|
||||||
var completions []string
|
|
||||||
conversations, _ := s.Conversations() // ignore error for completions
|
|
||||||
for _, conversation := range conversations {
|
|
||||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
|
||||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return completions
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
|
||||||
if shortName == "" {
|
|
||||||
return nil, errors.New("shortName is empty")
|
|
||||||
}
|
|
||||||
var conversation model.Conversation
|
|
||||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
|
||||||
return &conversation, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
|
||||||
var messages []model.Message
|
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
|
||||||
return messages, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
|
||||||
var message model.Message
|
|
||||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
|
||||||
return &message, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddReply adds the given messages as a reply to the given conversation, can be
|
|
||||||
// used to easily copy a message associated with one conversation, to another
|
|
||||||
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
|
|
||||||
m.ConversationID = c.ID
|
|
||||||
m.ID = 0
|
|
||||||
m.CreatedAt = time.Time{}
|
|
||||||
return &m, s.SaveMessage(&m)
|
|
||||||
}
|
|
@ -1,133 +0,0 @@
|
|||||||
package tools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
|
||||||
|
|
||||||
Useful for re-writing snippets/blocks of code or entire functions.
|
|
||||||
|
|
||||||
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
|
||||||
|
|
||||||
var FileReplaceLinesTool = model.Tool{
|
|
||||||
Name: "file_replace_lines",
|
|
||||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
|
||||||
Parameters: []model.ToolParameter{
|
|
||||||
{
|
|
||||||
Name: "path",
|
|
||||||
Type: "string",
|
|
||||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "start_line",
|
|
||||||
Type: "integer",
|
|
||||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "end_line",
|
|
||||||
Type: "integer",
|
|
||||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "content",
|
|
||||||
Type: "string",
|
|
||||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
|
||||||
tmp, ok := args["path"]
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
|
||||||
}
|
|
||||||
path, ok := tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
var start_line int
|
|
||||||
tmp, ok = args["start_line"]
|
|
||||||
if ok {
|
|
||||||
tmp, ok := tmp.(float64)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
start_line = int(tmp)
|
|
||||||
}
|
|
||||||
var end_line int
|
|
||||||
tmp, ok = args["end_line"]
|
|
||||||
if ok {
|
|
||||||
tmp, ok := tmp.(float64)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
end_line = int(tmp)
|
|
||||||
}
|
|
||||||
var content string
|
|
||||||
tmp, ok = args["content"]
|
|
||||||
if ok {
|
|
||||||
content, ok = tmp.(string)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
result := fileReplaceLines(path, start_line, end_line, content)
|
|
||||||
ret, err := result.ToJson()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
|
||||||
}
|
|
||||||
return ret, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
|
||||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
|
||||||
if !ok {
|
|
||||||
return model.CallResult{Message: reason}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read the existing file's content
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
if !os.IsNotExist(err) {
|
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
|
||||||
}
|
|
||||||
_, err = os.Create(path)
|
|
||||||
if err != nil {
|
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
|
||||||
}
|
|
||||||
data = []byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if startLine < 1 {
|
|
||||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := strings.Split(string(data), "\n")
|
|
||||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
|
||||||
|
|
||||||
if endLine == 0 || endLine > len(lines) {
|
|
||||||
endLine = len(lines)
|
|
||||||
}
|
|
||||||
|
|
||||||
before := lines[:startLine-1]
|
|
||||||
after := lines[endLine:]
|
|
||||||
|
|
||||||
lines = append(before, append(contentLines, after...)...)
|
|
||||||
newContent := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
// Join the lines and write back to the file
|
|
||||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
|
||||||
if err != nil {
|
|
||||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
|
||||||
}
|
|
||||||
|
|
||||||
return model.CallResult{Result: newContent}
|
|
||||||
}
|
|
@ -1,47 +0,0 @@
|
|||||||
package tools
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
|
||||||
"read_dir": ReadDirTool,
|
|
||||||
"read_file": ReadFileTool,
|
|
||||||
"write_file": WriteFileTool,
|
|
||||||
"file_insert_lines": FileInsertLinesTool,
|
|
||||||
"file_replace_lines": FileReplaceLinesTool,
|
|
||||||
}
|
|
||||||
|
|
||||||
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
|
||||||
var toolResults []model.ToolResult
|
|
||||||
for _, toolCall := range toolCalls {
|
|
||||||
var tool *model.Tool
|
|
||||||
for _, available := range toolBag {
|
|
||||||
if available.Name == toolCall.Name {
|
|
||||||
tool = &available
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if tool == nil {
|
|
||||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the tool
|
|
||||||
result, err := tool.Impl(tool, toolCall.Parameters)
|
|
||||||
if err != nil {
|
|
||||||
// This can happen if the model missed or supplied invalid tool args
|
|
||||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult := model.ToolResult{
|
|
||||||
ToolCallID: toolCall.ID,
|
|
||||||
ToolName: toolCall.Name,
|
|
||||||
Result: result,
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResults = append(toolResults, toolResult)
|
|
||||||
}
|
|
||||||
return toolResults, nil
|
|
||||||
}
|
|
447
pkg/provider/anthropic/anthropic.go
Normal file
447
pkg/provider/anthropic/anthropic.go
Normal file
@ -0,0 +1,447 @@
|
|||||||
|
package anthropic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ANTHROPIC_VERSION = "2023-06-01"
|
||||||
|
|
||||||
|
type AnthropicClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content interface{} `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
InputSchema InputSchema `json:"input_schema"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type InputSchema struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]Property `json:"properties"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Property struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
System string `json:"system,omitempty"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentBlock struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
ID string `json:"id,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Input interface{} `json:"input,omitempty"`
|
||||||
|
partialJsonAccumulator string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Content []ContentBlock `json:"content"`
|
||||||
|
StopReason string `json:"stop_reason"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message interface{} `json:"message,omitempty"`
|
||||||
|
Index int `json:"index,omitempty"`
|
||||||
|
Delta interface{} `json:"delta,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
|
anthropicTools := make([]Tool, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
properties := make(map[string]Property)
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
properties[param.Name] = Property{
|
||||||
|
Type: param.Type,
|
||||||
|
Description: param.Description,
|
||||||
|
Enum: param.Enum,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var required []string
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
if param.Required {
|
||||||
|
required = append(required, param.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anthropicTools[i] = Tool{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
InputSchema: InputSchema{
|
||||||
|
Type: "object",
|
||||||
|
Properties: properties,
|
||||||
|
Required: required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return anthropicTools
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChatCompletionRequest(
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (string, ChatCompletionRequest) {
|
||||||
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
|
var systemMessage string
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
if m.Role == api.MessageRoleSystem {
|
||||||
|
systemMessage = m.Content
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var content interface{}
|
||||||
|
role := string(m.Role)
|
||||||
|
|
||||||
|
switch m.Role {
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
role = "assistant"
|
||||||
|
contentBlocks := make([]map[string]interface{}, 0)
|
||||||
|
if m.Content != "" {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": m.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, toolCall := range m.ToolCalls {
|
||||||
|
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||||
|
"type": "tool_use",
|
||||||
|
"id": toolCall.ID,
|
||||||
|
"name": toolCall.Name,
|
||||||
|
"input": toolCall.Parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
content = contentBlocks
|
||||||
|
|
||||||
|
case api.MessageRoleToolResult:
|
||||||
|
role = "user"
|
||||||
|
contentBlocks := make([]map[string]interface{}, 0)
|
||||||
|
for _, result := range m.ToolResults {
|
||||||
|
contentBlock := map[string]interface{}{
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": result.ToolCallID,
|
||||||
|
"content": result.Result,
|
||||||
|
}
|
||||||
|
contentBlocks = append(contentBlocks, contentBlock)
|
||||||
|
}
|
||||||
|
content = contentBlocks
|
||||||
|
|
||||||
|
default:
|
||||||
|
content = m.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
requestMessages = append(requestMessages, ChatCompletionMessage{
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
request := ChatCompletionRequest{
|
||||||
|
Model: params.Model,
|
||||||
|
Messages: requestMessages,
|
||||||
|
System: systemMessage,
|
||||||
|
MaxTokens: params.MaxTokens,
|
||||||
|
Temperature: params.Temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.Toolbox) > 0 {
|
||||||
|
request.Tools = convertTools(params.Toolbox)
|
||||||
|
}
|
||||||
|
|
||||||
|
var prefill string
|
||||||
|
if len(messages) > 0 && messages[len(messages)-1].Role == api.MessageRoleAssistant {
|
||||||
|
// Prompting on an assitant message, use its content as prefill
|
||||||
|
prefill = messages[len(messages)-1].Content
|
||||||
|
}
|
||||||
|
|
||||||
|
return prefill, request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||||
|
jsonData, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("x-api-key", c.APIKey)
|
||||||
|
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
||||||
|
req.Header.Set("content-type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AnthropicClient) CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, req := createChatCompletionRequest(params, messages)
|
||||||
|
req.Stream = false
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var completionResp ChatCompletionResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return convertResponseToMessage(completionResp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
output chan<- provider.Chunk,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
prefill, req := createChatCompletionRequest(params, messages)
|
||||||
|
req.Stream = true
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
contentBlocks := make(map[int]*ContentBlock)
|
||||||
|
var finalMessage *ChatCompletionResponse
|
||||||
|
|
||||||
|
var firstChunkReceived bool
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
|
||||||
|
var streamEvent StreamEvent
|
||||||
|
err = json.Unmarshal(line, &streamEvent)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch streamEvent.Type {
|
||||||
|
case "message_start":
|
||||||
|
finalMessage = &ChatCompletionResponse{}
|
||||||
|
err = json.Unmarshal(line, &struct {
|
||||||
|
Message *ChatCompletionResponse `json:"message"`
|
||||||
|
}{Message: finalMessage})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
||||||
|
}
|
||||||
|
case "content_block_start":
|
||||||
|
var contentBlockStart struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
ContentBlock ContentBlock `json:"content_block"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(line, &contentBlockStart)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
||||||
|
case "content_block_delta":
|
||||||
|
if streamEvent.Index >= len(contentBlocks) {
|
||||||
|
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
block := contentBlocks[streamEvent.Index]
|
||||||
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
deltaType, ok := delta["type"].(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("delta missing type field")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deltaType {
|
||||||
|
case "text_delta":
|
||||||
|
if text, ok := delta["text"].(string); ok {
|
||||||
|
if !firstChunkReceived {
|
||||||
|
if prefill == "" {
|
||||||
|
// if there is no prefil, ensure we trim leading whitespace
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
firstChunkReceived = true
|
||||||
|
}
|
||||||
|
block.Text += text
|
||||||
|
output <- provider.Chunk{
|
||||||
|
Content: text,
|
||||||
|
// rough, anthropic performs some chunking
|
||||||
|
TokenCount: uint(len(strings.Split(text, " "))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "input_json_delta":
|
||||||
|
if block.Type != "tool_use" {
|
||||||
|
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
||||||
|
}
|
||||||
|
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||||
|
block.partialJsonAccumulator += partialJSON
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
if streamEvent.Index >= len(contentBlocks) {
|
||||||
|
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
block := contentBlocks[streamEvent.Index]
|
||||||
|
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
||||||
|
var inputData map[string]interface{}
|
||||||
|
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
||||||
|
}
|
||||||
|
block.Input = inputData
|
||||||
|
}
|
||||||
|
case "message_delta":
|
||||||
|
if finalMessage == nil {
|
||||||
|
return nil, fmt.Errorf("received message_delta before message_start")
|
||||||
|
}
|
||||||
|
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
||||||
|
}
|
||||||
|
if stopReason, ok := delta["stop_reason"].(string); ok {
|
||||||
|
finalMessage.StopReason = stopReason
|
||||||
|
}
|
||||||
|
|
||||||
|
case "message_stop":
|
||||||
|
// End of the stream
|
||||||
|
goto END_STREAM
|
||||||
|
|
||||||
|
case "error":
|
||||||
|
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Ignore unknown event types
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
END_STREAM:
|
||||||
|
if finalMessage == nil {
|
||||||
|
return nil, fmt.Errorf("no final message received")
|
||||||
|
}
|
||||||
|
|
||||||
|
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
||||||
|
for _, v := range contentBlocks {
|
||||||
|
finalMessage.Content = append(finalMessage.Content, *v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return convertResponseToMessage(*finalMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
||||||
|
content := strings.Builder{}
|
||||||
|
var toolCalls []api.ToolCall
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "text":
|
||||||
|
content.WriteString(block.Text)
|
||||||
|
case "tool_use":
|
||||||
|
parameters, ok := block.Input.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, api.ToolCall{
|
||||||
|
ID: block.ID,
|
||||||
|
Name: block.Name,
|
||||||
|
Parameters: parameters,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return api.NewMessageWithToolCalls(content.String(), toolCalls), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
|
}
|
436
pkg/provider/google/google.go
Normal file
436
pkg/provider/google/google.go
Normal file
@ -0,0 +1,436 @@
|
|||||||
|
package google
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ContentPart struct {
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||||
|
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Args map[string]string `json:"args"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Response interface{} `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Content struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Parts []ContentPart `json:"parts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerationConfig struct {
|
||||||
|
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||||
|
Temperature *float32 `json:"temperature,omitempty"`
|
||||||
|
TopP *float32 `json:"topP,omitempty"`
|
||||||
|
TopK *int `json:"topK,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentRequest struct {
|
||||||
|
Contents []Content `json:"contents"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||||
|
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Candidate struct {
|
||||||
|
Content Content `json:"content"`
|
||||||
|
FinishReason string `json:"finishReason"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GenerateContentResponse struct {
|
||||||
|
Candidates []Candidate `json:"candidates"`
|
||||||
|
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDeclaration struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Values []string `json:"values,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
|
geminiTools := make([]Tool, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
params := make(map[string]ToolParameter)
|
||||||
|
var required []string
|
||||||
|
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
// TODO: proper enum handing
|
||||||
|
params[param.Name] = ToolParameter{
|
||||||
|
Type: param.Type,
|
||||||
|
Description: param.Description,
|
||||||
|
Values: param.Enum,
|
||||||
|
}
|
||||||
|
if param.Required {
|
||||||
|
required = append(required, param.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiTools[i] = Tool{
|
||||||
|
FunctionDeclarations: []FunctionDeclaration{
|
||||||
|
{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: ToolParameters{
|
||||||
|
Type: "OBJECT",
|
||||||
|
Properties: params,
|
||||||
|
Required: required,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return geminiTools
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
||||||
|
converted := make([]ContentPart, len(toolCalls))
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
args := make(map[string]string)
|
||||||
|
for k, v := range call.Parameters {
|
||||||
|
args[k] = fmt.Sprintf("%v", v)
|
||||||
|
}
|
||||||
|
converted[i].FunctionCall = &FunctionCall{
|
||||||
|
Name: call.Name,
|
||||||
|
Args: args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
||||||
|
converted := make([]api.ToolCall, len(functionCalls))
|
||||||
|
for i, call := range functionCalls {
|
||||||
|
params := make(map[string]interface{})
|
||||||
|
for k, v := range call.Args {
|
||||||
|
params[k] = v
|
||||||
|
}
|
||||||
|
converted[i].Name = call.Name
|
||||||
|
converted[i].Parameters = params
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
||||||
|
results := make([]FunctionResponse, len(toolResults))
|
||||||
|
for i, result := range toolResults {
|
||||||
|
var obj interface{}
|
||||||
|
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||||
|
}
|
||||||
|
results[i] = FunctionResponse{
|
||||||
|
Name: result.ToolName,
|
||||||
|
Response: obj,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createGenerateContentRequest(
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*GenerateContentRequest, error) {
|
||||||
|
requestContents := make([]Content, 0, len(messages))
|
||||||
|
|
||||||
|
startIdx := 0
|
||||||
|
var system string
|
||||||
|
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||||
|
system = messages[0].Content
|
||||||
|
startIdx = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range messages[startIdx:] {
|
||||||
|
switch m.Role {
|
||||||
|
case "tool_call":
|
||||||
|
content := Content{
|
||||||
|
Role: "model",
|
||||||
|
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||||
|
}
|
||||||
|
requestContents = append(requestContents, content)
|
||||||
|
case "tool_result":
|
||||||
|
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// expand tool_result messages' results into multiple gemini messages
|
||||||
|
for _, result := range results {
|
||||||
|
content := Content{
|
||||||
|
Role: "function",
|
||||||
|
Parts: []ContentPart{
|
||||||
|
{
|
||||||
|
FunctionResp: &result,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
requestContents = append(requestContents, content)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var role string
|
||||||
|
switch m.Role {
|
||||||
|
case api.MessageRoleAssistant:
|
||||||
|
role = "model"
|
||||||
|
case api.MessageRoleUser:
|
||||||
|
role = "user"
|
||||||
|
}
|
||||||
|
|
||||||
|
if role == "" {
|
||||||
|
panic("Unhandled role: " + m.Role)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := Content{
|
||||||
|
Role: role,
|
||||||
|
Parts: []ContentPart{
|
||||||
|
{
|
||||||
|
Text: m.Content,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
requestContents = append(requestContents, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request := &GenerateContentRequest{
|
||||||
|
Contents: requestContents,
|
||||||
|
GenerationConfig: &GenerationConfig{
|
||||||
|
MaxOutputTokens: ¶ms.MaxTokens,
|
||||||
|
Temperature: ¶ms.Temperature,
|
||||||
|
TopP: ¶ms.TopP,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if system != "" {
|
||||||
|
request.SystemInstruction = &Content{
|
||||||
|
Parts: []ContentPart{
|
||||||
|
{
|
||||||
|
Text: system,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.Toolbox) > 0 {
|
||||||
|
request.Tools = convertTools(params.Toolbox)
|
||||||
|
}
|
||||||
|
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf(
|
||||||
|
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||||
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
|
)
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var completionResp GenerateContentResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := completionResp.Candidates[0]
|
||||||
|
|
||||||
|
var content string
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
content = lastMessage.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolCalls []FunctionCall
|
||||||
|
for _, part := range choice.Content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
content += part.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
output chan<- provider.Chunk,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := createGenerateContentRequest(params, messages)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf(
|
||||||
|
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||||
|
c.BaseURL, params.Model, c.APIKey,
|
||||||
|
)
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
content.WriteString(lastMessage.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolCalls []FunctionCall
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
|
||||||
|
lastTokenCount := 0
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
|
||||||
|
var resp GenerateContentResponse
|
||||||
|
err = json.Unmarshal(line, &resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||||
|
lastTokenCount += tokens
|
||||||
|
|
||||||
|
choice := resp.Candidates[0]
|
||||||
|
for _, part := range choice.Content.Parts {
|
||||||
|
if part.FunctionCall != nil {
|
||||||
|
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||||
|
} else if part.Text != "" {
|
||||||
|
output <- provider.Chunk{
|
||||||
|
Content: part.Text,
|
||||||
|
TokenCount: uint(tokens),
|
||||||
|
}
|
||||||
|
content.WriteString(part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
|
}
|
183
pkg/provider/ollama/ollama.go
Normal file
183
pkg/provider/ollama/ollama.go
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
package ollama
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OllamaClient struct {
|
||||||
|
BaseURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
Messages []OllamaMessage `json:"messages"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OllamaResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
Message OllamaMessage `json:"message"`
|
||||||
|
Done bool `json:"done"`
|
||||||
|
TotalDuration uint64 `json:"total_duration,omitempty"`
|
||||||
|
LoadDuration uint64 `json:"load_duration,omitempty"`
|
||||||
|
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
|
||||||
|
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
|
||||||
|
EvalCount uint64 `json:"eval_count,omitempty"`
|
||||||
|
EvalDuration uint64 `json:"eval_duration,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOllamaRequest(
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) OllamaRequest {
|
||||||
|
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
message := OllamaMessage{
|
||||||
|
Role: string(m.Role),
|
||||||
|
Content: m.Content,
|
||||||
|
}
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
request := OllamaRequest{
|
||||||
|
Model: params.Model,
|
||||||
|
Messages: requestMessages,
|
||||||
|
}
|
||||||
|
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createOllamaRequest(params, messages)
|
||||||
|
req.Stream = false
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var completionResp OllamaResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(completionResp.Message.Content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OllamaClient) CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
output chan<- provider.Chunk,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createOllamaRequest(params, messages)
|
||||||
|
req.Stream = true
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var streamResp OllamaResponse
|
||||||
|
err = json.Unmarshal(line, &streamResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(streamResp.Message.Content) > 0 {
|
||||||
|
output <- provider.Chunk{
|
||||||
|
Content: streamResp.Message.Content,
|
||||||
|
TokenCount: 1,
|
||||||
|
}
|
||||||
|
content.WriteString(streamResp.Message.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
|
}
|
343
pkg/provider/openai/openai.go
Normal file
343
pkg/provider/openai/openai.go
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OpenAIClient struct {
|
||||||
|
APIKey string
|
||||||
|
BaseURL string
|
||||||
|
Headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionDefinition struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters ToolParameters `json:"parameters"`
|
||||||
|
Arguments string `json:"arguments,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolParameter struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Enum []string `json:"enum,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function FunctionDefinition `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float32 `json:"temperature,omitempty"`
|
||||||
|
Messages []ChatCompletionMessage `json:"messages"`
|
||||||
|
N int `json:"n"`
|
||||||
|
Tools []Tool `json:"tools,omitempty"`
|
||||||
|
ToolChoice string `json:"tool_choice,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionChoice struct {
|
||||||
|
Message ChatCompletionMessage `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
Choices []ChatCompletionChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamChoice struct {
|
||||||
|
Delta ChatCompletionMessage `json:"delta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionStreamResponse struct {
|
||||||
|
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertTools(tools []api.ToolSpec) []Tool {
|
||||||
|
openaiTools := make([]Tool, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
openaiTools[i].Type = "function"
|
||||||
|
|
||||||
|
params := make(map[string]ToolParameter)
|
||||||
|
var required []string
|
||||||
|
|
||||||
|
for _, param := range tool.Parameters {
|
||||||
|
params[param.Name] = ToolParameter{
|
||||||
|
Type: param.Type,
|
||||||
|
Description: param.Description,
|
||||||
|
Enum: param.Enum,
|
||||||
|
}
|
||||||
|
if param.Required {
|
||||||
|
required = append(required, param.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiTools[i].Function = FunctionDefinition{
|
||||||
|
Name: tool.Name,
|
||||||
|
Description: tool.Description,
|
||||||
|
Parameters: ToolParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: params,
|
||||||
|
Required: required,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return openaiTools
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
||||||
|
converted := make([]ToolCall, len(toolCalls))
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
converted[i].Type = "function"
|
||||||
|
converted[i].ID = call.ID
|
||||||
|
converted[i].Function.Name = call.Name
|
||||||
|
|
||||||
|
json, _ := json.Marshal(call.Parameters)
|
||||||
|
converted[i].Function.Arguments = string(json)
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||||
|
converted := make([]api.ToolCall, len(toolCalls))
|
||||||
|
for i, call := range toolCalls {
|
||||||
|
converted[i].ID = call.ID
|
||||||
|
converted[i].Name = call.Function.Name
|
||||||
|
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||||
|
}
|
||||||
|
return converted
|
||||||
|
}
|
||||||
|
|
||||||
|
func createChatCompletionRequest(
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) ChatCompletionRequest {
|
||||||
|
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||||
|
|
||||||
|
for _, m := range messages {
|
||||||
|
switch m.Role {
|
||||||
|
case "tool_call":
|
||||||
|
message := ChatCompletionMessage{}
|
||||||
|
message.Role = "assistant"
|
||||||
|
message.Content = m.Content
|
||||||
|
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
case "tool_result":
|
||||||
|
// expand tool_result messages' results into multiple openAI messages
|
||||||
|
for _, result := range m.ToolResults {
|
||||||
|
message := ChatCompletionMessage{}
|
||||||
|
message.Role = "tool"
|
||||||
|
message.Content = result.Result
|
||||||
|
message.ToolCallID = result.ToolCallID
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
message := ChatCompletionMessage{}
|
||||||
|
message.Role = string(m.Role)
|
||||||
|
message.Content = m.Content
|
||||||
|
requestMessages = append(requestMessages, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
request := ChatCompletionRequest{
|
||||||
|
Model: params.Model,
|
||||||
|
MaxTokens: params.MaxTokens,
|
||||||
|
Temperature: params.Temperature,
|
||||||
|
Messages: requestMessages,
|
||||||
|
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.Toolbox) > 0 {
|
||||||
|
request.Tools = convertTools(params.Toolbox)
|
||||||
|
request.ToolChoice = "auto"
|
||||||
|
}
|
||||||
|
|
||||||
|
return request
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||||
|
jsonData, err := json.Marshal(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||||
|
for header, val := range c.Headers {
|
||||||
|
req.Header.Set(header, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
bytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return resp, fmt.Errorf("%v", string(bytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIClient) CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createChatCompletionRequest(params, messages)
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var completionResp ChatCompletionResponse
|
||||||
|
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
choice := completionResp.Choices[0]
|
||||||
|
|
||||||
|
var content string
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
content = lastMessage.Content + choice.Message.Content
|
||||||
|
} else {
|
||||||
|
content = choice.Message.Content
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls := choice.Message.ToolCalls
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return api.NewMessageWithToolCalls(content, convertToolCallToAPI(toolCalls)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params provider.RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
output chan<- provider.Chunk,
|
||||||
|
) (*api.Message, error) {
|
||||||
|
if len(messages) == 0 {
|
||||||
|
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||||
|
}
|
||||||
|
|
||||||
|
req := createChatCompletionRequest(params, messages)
|
||||||
|
req.Stream = true
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
content := strings.Builder{}
|
||||||
|
toolCalls := []ToolCall{}
|
||||||
|
|
||||||
|
lastMessage := messages[len(messages)-1]
|
||||||
|
if lastMessage.Role.IsAssistant() {
|
||||||
|
content.WriteString(lastMessage.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(resp.Body)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
if bytes.Equal(line, []byte("[DONE]")) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
var streamResp ChatCompletionStreamResponse
|
||||||
|
err = json.Unmarshal(line, &streamResp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
delta := streamResp.Choices[0].Delta
|
||||||
|
if len(delta.ToolCalls) > 0 {
|
||||||
|
// Construct streamed tool_call arguments
|
||||||
|
for _, tc := range delta.ToolCalls {
|
||||||
|
if tc.Index == nil {
|
||||||
|
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||||
|
}
|
||||||
|
if len(toolCalls) <= *tc.Index {
|
||||||
|
toolCalls = append(toolCalls, tc)
|
||||||
|
} else {
|
||||||
|
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(delta.Content) > 0 {
|
||||||
|
output <- provider.Chunk{
|
||||||
|
Content: delta.Content,
|
||||||
|
TokenCount: 1,
|
||||||
|
}
|
||||||
|
content.WriteString(delta.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
return api.NewMessageWithToolCalls(content.String(), convertToolCallToAPI(toolCalls)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.NewMessageWithAssistant(content.String()), nil
|
||||||
|
}
|
41
pkg/provider/provider.go
Normal file
41
pkg/provider/provider.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package provider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Chunk struct {
|
||||||
|
Content string
|
||||||
|
TokenCount uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestParameters struct {
|
||||||
|
Model string
|
||||||
|
|
||||||
|
MaxTokens int
|
||||||
|
Temperature float32
|
||||||
|
TopP float32
|
||||||
|
|
||||||
|
Toolbox []api.ToolSpec
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatCompletionProvider interface {
|
||||||
|
// CreateChatCompletion generates a chat completion response to the
|
||||||
|
// provided messages.
|
||||||
|
CreateChatCompletion(
|
||||||
|
ctx context.Context,
|
||||||
|
params RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
) (*api.Message, error)
|
||||||
|
|
||||||
|
// Like CreateChageCompletion, except the response is streamed via
|
||||||
|
// the output channel.
|
||||||
|
CreateChatCompletionStream(
|
||||||
|
ctx context.Context,
|
||||||
|
params RequestParameters,
|
||||||
|
messages []api.Message,
|
||||||
|
chunks chan<- Chunk,
|
||||||
|
) (*api.Message, error)
|
||||||
|
}
|
67
pkg/tui/bubbles/confirmprompt.go
Normal file
67
pkg/tui/bubbles/confirmprompt.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package bubbles
|
||||||
|
|
||||||
|
import (
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConfirmPrompt struct {
|
||||||
|
Question string
|
||||||
|
Style lipgloss.Style
|
||||||
|
Payload interface{}
|
||||||
|
value bool
|
||||||
|
answered bool
|
||||||
|
focused bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfirmPrompt(question string, payload interface{}) ConfirmPrompt {
|
||||||
|
return ConfirmPrompt{
|
||||||
|
Question: question,
|
||||||
|
Style: lipgloss.NewStyle(),
|
||||||
|
Payload: payload,
|
||||||
|
focused: true, // focus by default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgConfirmPromptAnswered struct {
|
||||||
|
Value bool
|
||||||
|
Payload interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b ConfirmPrompt) Update(msg tea.Msg) (ConfirmPrompt, tea.Cmd) {
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
if !b.focused || b.answered {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
switch msg.String() {
|
||||||
|
case "y", "Y":
|
||||||
|
b.value = true
|
||||||
|
b.answered = true
|
||||||
|
b.focused = false
|
||||||
|
return b, func() tea.Msg { return MsgConfirmPromptAnswered{true, b.Payload} }
|
||||||
|
case "n", "N", "esc":
|
||||||
|
b.value = false
|
||||||
|
b.answered = true
|
||||||
|
b.focused = false
|
||||||
|
return b, func() tea.Msg { return MsgConfirmPromptAnswered{false, b.Payload} }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b ConfirmPrompt) View() string {
|
||||||
|
return b.Style.Render(b.Question) + lipgloss.NewStyle().Faint(true).Render(" (y/n)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ConfirmPrompt) Focus() {
|
||||||
|
b.focused = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *ConfirmPrompt) Blur() {
|
||||||
|
b.focused = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b ConfirmPrompt) Focused() bool {
|
||||||
|
return b.focused
|
||||||
|
}
|
260
pkg/tui/bubbles/list/list.go
Normal file
260
pkg/tui/bubbles/list/list.go
Normal file
@ -0,0 +1,260 @@
|
|||||||
|
package list
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
|
"github.com/charmbracelet/bubbles/textinput"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Option struct {
|
||||||
|
Label string
|
||||||
|
Value interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type OptionGroup struct {
|
||||||
|
Name string
|
||||||
|
Options []Option
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
ID int
|
||||||
|
HeaderStyle lipgloss.Style
|
||||||
|
ItemStyle lipgloss.Style
|
||||||
|
SelectedStyle lipgloss.Style
|
||||||
|
ItemRender func(Option, bool) string
|
||||||
|
|
||||||
|
Width int
|
||||||
|
Height int
|
||||||
|
|
||||||
|
optionGroups []OptionGroup
|
||||||
|
selected int
|
||||||
|
filterInput textinput.Model
|
||||||
|
filteredIndices []filteredIndex
|
||||||
|
content viewport.Model
|
||||||
|
itemYOffsets []int
|
||||||
|
}
|
||||||
|
|
||||||
|
type filteredIndex struct {
|
||||||
|
groupIndex int
|
||||||
|
optionIndex int
|
||||||
|
}
|
||||||
|
|
||||||
|
type MsgOptionSelected struct {
|
||||||
|
ID int
|
||||||
|
Option Option
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(opts []Option) Model {
|
||||||
|
return NewWithGroups([]OptionGroup{{Options: opts}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWithGroups(groups []OptionGroup) Model {
|
||||||
|
ti := textinput.New()
|
||||||
|
ti.Prompt = "/"
|
||||||
|
ti.PromptStyle = lipgloss.NewStyle().Faint(true)
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
HeaderStyle: lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("12")).Padding(1, 0, 1, 1),
|
||||||
|
ItemStyle: lipgloss.NewStyle(),
|
||||||
|
SelectedStyle: lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6")),
|
||||||
|
|
||||||
|
optionGroups: groups,
|
||||||
|
selected: 0,
|
||||||
|
filterInput: ti,
|
||||||
|
filteredIndices: make([]filteredIndex, 0),
|
||||||
|
content: viewport.New(0, 0),
|
||||||
|
itemYOffsets: make([]int, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.filterItems()
|
||||||
|
m.content.SetContent(m.renderOptionsList())
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Focused() {
|
||||||
|
m.filterInput.Focused()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Focus() {
|
||||||
|
m.filterInput.Focus()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Blur() {
|
||||||
|
m.filterInput.Blur()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) filterItems() {
|
||||||
|
filterText := strings.ToLower(m.filterInput.Value())
|
||||||
|
|
||||||
|
var prevSelection *filteredIndex
|
||||||
|
if m.selected <= len(m.filteredIndices)-1 {
|
||||||
|
prevSelection = &m.filteredIndices[m.selected]
|
||||||
|
}
|
||||||
|
|
||||||
|
m.filteredIndices = make([]filteredIndex, 0)
|
||||||
|
|
||||||
|
for groupIndex, group := range m.optionGroups {
|
||||||
|
for optionIndex, option := range group.Options {
|
||||||
|
if filterText == "" ||
|
||||||
|
strings.Contains(strings.ToLower(option.Label), filterText) ||
|
||||||
|
(group.Name != "" && strings.Contains(strings.ToLower(group.Name), filterText)) {
|
||||||
|
m.filteredIndices = append(m.filteredIndices, filteredIndex{groupIndex, optionIndex})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
if len(m.filteredIndices) > 0 && prevSelection != nil {
|
||||||
|
// Preserve previous selection if possible
|
||||||
|
for i, filterIdx := range m.filteredIndices {
|
||||||
|
if prevSelection.groupIndex == filterIdx.groupIndex && prevSelection.optionIndex == filterIdx.optionIndex {
|
||||||
|
m.selected = i
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
m.selected = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
if m.filterInput.Focused() {
|
||||||
|
switch msg.String() {
|
||||||
|
case "esc":
|
||||||
|
m.filterInput.Blur()
|
||||||
|
m.filterInput.SetValue("")
|
||||||
|
m.filterItems()
|
||||||
|
m.refreshContent()
|
||||||
|
return *m, shared.KeyHandled(msg)
|
||||||
|
case "enter":
|
||||||
|
m.filterInput.Blur()
|
||||||
|
m.refreshContent()
|
||||||
|
break
|
||||||
|
case "up", "down":
|
||||||
|
break
|
||||||
|
default:
|
||||||
|
m.filterInput, cmd = m.filterInput.Update(msg)
|
||||||
|
m.filterItems()
|
||||||
|
m.refreshContent()
|
||||||
|
return *m, cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "up", "k":
|
||||||
|
m.moveSelection(-1)
|
||||||
|
return *m, shared.KeyHandled(msg)
|
||||||
|
case "down", "j":
|
||||||
|
m.moveSelection(1)
|
||||||
|
return *m, shared.KeyHandled(msg)
|
||||||
|
case "enter":
|
||||||
|
return *m, func() tea.Msg {
|
||||||
|
idx := m.filteredIndices[m.selected]
|
||||||
|
return MsgOptionSelected{
|
||||||
|
ID: m.ID,
|
||||||
|
Option: m.optionGroups[idx.groupIndex].Options[idx.optionIndex],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "/":
|
||||||
|
m.filterInput.Focus()
|
||||||
|
return *m, textinput.Blink
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.content, cmd = m.content.Update(msg)
|
||||||
|
return *m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) refreshContent() {
|
||||||
|
m.content.SetContent(m.renderOptionsList())
|
||||||
|
m.ensureSelectedVisible()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) ensureSelectedVisible() {
|
||||||
|
if m.selected == 0 {
|
||||||
|
m.content.GotoTop()
|
||||||
|
} else if m.selected == len(m.filteredIndices)-1 {
|
||||||
|
m.content.GotoBottom()
|
||||||
|
} else {
|
||||||
|
tuiutil.ScrollIntoView(&m.content, m.itemYOffsets[m.selected], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) moveSelection(delta int) {
|
||||||
|
prev := m.selected
|
||||||
|
m.selected = min(len(m.filteredIndices)-1, max(0, m.selected+delta))
|
||||||
|
if prev != m.selected {
|
||||||
|
m.refreshContent()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) View() string {
|
||||||
|
filter := ""
|
||||||
|
if m.filterInput.Focused() {
|
||||||
|
m.filterInput.Width = m.Width
|
||||||
|
filter = m.filterInput.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
contentHeight := m.Height - tuiutil.Height(filter)
|
||||||
|
m.content.Width, m.content.Height = m.Width, contentHeight
|
||||||
|
|
||||||
|
parts := []string{m.content.View()}
|
||||||
|
if filter != "" {
|
||||||
|
parts = append(parts, filter)
|
||||||
|
}
|
||||||
|
return lipgloss.JoinVertical(lipgloss.Left, parts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) renderOptionsList() string {
|
||||||
|
yOffset := 0
|
||||||
|
lastGroupIndex := -1
|
||||||
|
m.itemYOffsets = make([]int, len(m.filteredIndices))
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, idx := range m.filteredIndices {
|
||||||
|
if idx.groupIndex != lastGroupIndex {
|
||||||
|
group := m.optionGroups[idx.groupIndex].Name
|
||||||
|
if group != "" {
|
||||||
|
headingStr := m.HeaderStyle.Render(group)
|
||||||
|
yOffset += tuiutil.Height(headingStr)
|
||||||
|
sb.WriteString(headingStr)
|
||||||
|
sb.WriteRune('\n')
|
||||||
|
}
|
||||||
|
lastGroupIndex = idx.groupIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
m.itemYOffsets[i] = yOffset
|
||||||
|
option := m.optionGroups[idx.groupIndex].Options[idx.optionIndex]
|
||||||
|
var item string
|
||||||
|
if m.ItemRender != nil {
|
||||||
|
item = m.ItemRender(option, i == m.selected)
|
||||||
|
} else {
|
||||||
|
prefix := " "
|
||||||
|
if i == m.selected {
|
||||||
|
prefix = "> "
|
||||||
|
item = m.SelectedStyle.Render(option.Label)
|
||||||
|
} else {
|
||||||
|
item = m.ItemStyle.Render(option.Label)
|
||||||
|
}
|
||||||
|
item = fmt.Sprintf("%s%s", prefix, item)
|
||||||
|
}
|
||||||
|
sb.WriteString(item)
|
||||||
|
yOffset += tuiutil.Height(item)
|
||||||
|
if i < len(m.filteredIndices)-1 {
|
||||||
|
sb.WriteRune('\n')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
281
pkg/tui/model/model.go
Normal file
281
pkg/tui/model/model.go
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AppModel struct {
|
||||||
|
Ctx *lmcli.Context
|
||||||
|
Conversations conversation.ConversationList
|
||||||
|
Conversation conversation.Conversation
|
||||||
|
Messages []conversation.Message
|
||||||
|
Model string
|
||||||
|
ProviderName string
|
||||||
|
Provider provider.ChatCompletionProvider
|
||||||
|
Agent *lmcli.Agent
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAppModel(ctx *lmcli.Context, initialConversation *conversation.Conversation) *AppModel {
|
||||||
|
app := &AppModel{
|
||||||
|
Ctx: ctx,
|
||||||
|
Model: *ctx.Config.Defaults.Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
if initialConversation == nil {
|
||||||
|
app.NewConversation()
|
||||||
|
} else {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
model, provider, _, _ := ctx.GetModelProvider(*ctx.Config.Defaults.Model, "")
|
||||||
|
app.Model = model
|
||||||
|
app.ProviderName = provider
|
||||||
|
app.Agent = ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultStyle = lipgloss.NewStyle().Faint(true)
|
||||||
|
accentStyle = defaultStyle.Foreground(lipgloss.Color("6"))
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *AppModel) ActiveModel(style lipgloss.Style) string {
|
||||||
|
defaultStyle := style.Inherit(defaultStyle)
|
||||||
|
accentStyle := style.Inherit(accentStyle)
|
||||||
|
return defaultStyle.Render(a.Model) + accentStyle.Render("@") + defaultStyle.Render(a.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MessageCycleDirection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
CycleNext MessageCycleDirection = 1
|
||||||
|
CyclePrev MessageCycleDirection = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *AppModel) ClearConversation() {
|
||||||
|
m.Conversation = conversation.Conversation{}
|
||||||
|
m.Messages = []conversation.Message{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AppModel) ApplySystemPrompt() {
|
||||||
|
var system string
|
||||||
|
agent := m.Ctx.GetAgent(m.Ctx.Config.Defaults.Agent)
|
||||||
|
if agent != nil && agent.SystemPrompt != "" {
|
||||||
|
system = agent.SystemPrompt
|
||||||
|
}
|
||||||
|
if system == "" {
|
||||||
|
system = m.Ctx.DefaultSystemPrompt()
|
||||||
|
}
|
||||||
|
if system != "" {
|
||||||
|
m.Messages = conversation.ApplySystemPrompt(m.Messages, system, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AppModel) NewConversation() {
|
||||||
|
m.ClearConversation()
|
||||||
|
m.ApplySystemPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) LoadConversationMessages() ([]conversation.Message, error) {
|
||||||
|
messages, err := a.Ctx.Conversations.PathToLeaf(a.Conversation.SelectedRoot)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not load conversation messages: %v %v", a.Conversation.SelectedRoot, err)
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) GenerateConversationTitle(messages []conversation.Message) (string, error) {
|
||||||
|
return cmdutil.GenerateTitle(a.Ctx, messages)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) CloneMessage(message conversation.Message, selected bool) (*conversation.Message, error) {
|
||||||
|
msg, _, err := a.Ctx.Conversations.CloneBranch(message)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not clone message: %v", err)
|
||||||
|
}
|
||||||
|
if selected {
|
||||||
|
if msg.Parent == nil {
|
||||||
|
msg.Conversation.SelectedRoot = msg
|
||||||
|
err = a.Ctx.Conversations.UpdateConversation(msg.Conversation)
|
||||||
|
} else {
|
||||||
|
msg.Parent.SelectedReply = msg
|
||||||
|
err = a.Ctx.Conversations.UpdateMessage(msg.Parent)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not update selected message: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) UpdateMessageContent(message *conversation.Message) error {
|
||||||
|
return a.Ctx.Conversations.UpdateMessage(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cycleSelectedMessage(selected *conversation.Message, choices []conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
|
currentIndex := -1
|
||||||
|
for i, reply := range choices {
|
||||||
|
if reply.ID == selected.ID {
|
||||||
|
currentIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentIndex < 0 {
|
||||||
|
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var next int
|
||||||
|
if dir == CyclePrev {
|
||||||
|
next = (currentIndex - 1 + len(choices)) % len(choices)
|
||||||
|
} else {
|
||||||
|
next = (currentIndex + 1) % len(choices)
|
||||||
|
}
|
||||||
|
return &choices[next], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) CycleSelectedRoot(conv *conversation.Conversation, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
|
if len(conv.RootMessages) < 2 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, conv.RootMessages, dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conv.SelectedRoot = nextRoot
|
||||||
|
err = a.Ctx.Conversations.UpdateConversation(conv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not update conversation SelectedRoot: %v", err)
|
||||||
|
}
|
||||||
|
return nextRoot, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) CycleSelectedReply(message *conversation.Message, dir MessageCycleDirection) (*conversation.Message, error) {
|
||||||
|
if len(message.Replies) < 2 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
message.SelectedReply = nextReply
|
||||||
|
err = a.Ctx.Conversations.UpdateMessage(message)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not update message SelectedReply: %v", err)
|
||||||
|
}
|
||||||
|
return nextReply, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) PersistMessages() ([]conversation.Message, error) {
|
||||||
|
messages := make([]conversation.Message, len(a.Messages))
|
||||||
|
for i, m := range a.Messages {
|
||||||
|
if i == 0 && m.ID == 0 {
|
||||||
|
m.Conversation = &a.Conversation
|
||||||
|
m, err := a.Ctx.Conversations.SaveMessage(m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not create first message %d: %v", a.Messages[i].ID, err)
|
||||||
|
}
|
||||||
|
messages[i] = *m
|
||||||
|
// let's set the conversation root message(s), as this is the first message
|
||||||
|
m.Conversation.RootMessages = []conversation.Message{*m}
|
||||||
|
m.Conversation.SelectedRoot = &m.Conversation.RootMessages[0]
|
||||||
|
a.Ctx.Conversations.UpdateConversation(m.Conversation)
|
||||||
|
} else if m.ID > 0 {
|
||||||
|
// Existing message, update it
|
||||||
|
err := a.Ctx.Conversations.UpdateMessage(&m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not update message %d: %v", a.Messages[i].ID, err)
|
||||||
|
}
|
||||||
|
messages[i] = m
|
||||||
|
} else if i > 0 {
|
||||||
|
// New message, reply to previous
|
||||||
|
replies, err := a.Ctx.Conversations.Reply(&messages[i-1], m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Could not reply with new message: %v", err)
|
||||||
|
}
|
||||||
|
messages[i] = replies[0]
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("No messages to reply to (this is a bug)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) PersistConversation() (conversation.Conversation, error) {
|
||||||
|
conv := a.Conversation
|
||||||
|
var err error
|
||||||
|
if a.Conversation.ID > 0 {
|
||||||
|
err = a.Ctx.Conversations.UpdateConversation(&conv)
|
||||||
|
} else {
|
||||||
|
c, e := a.Ctx.Conversations.CreateConversation("")
|
||||||
|
err = e
|
||||||
|
if e == nil && c != nil {
|
||||||
|
conv = *c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conv, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) ExecuteToolCalls(toolCalls []api.ToolCall) ([]api.ToolResult, error) {
|
||||||
|
agent := a.Ctx.GetAgent(a.Ctx.Config.Defaults.Agent)
|
||||||
|
if agent == nil {
|
||||||
|
return nil, fmt.Errorf("Attempted to execute tool calls with no agent configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
return agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AppModel) Prompt(
|
||||||
|
messages []conversation.Message,
|
||||||
|
chatReplyChunks chan provider.Chunk,
|
||||||
|
stopSignal chan struct{},
|
||||||
|
) (*conversation.Message, error) {
|
||||||
|
model, _, p, err := a.Ctx.GetModelProvider(a.Model, a.ProviderName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := provider.RequestParameters{
|
||||||
|
Model: model,
|
||||||
|
MaxTokens: *a.Ctx.Config.Defaults.MaxTokens,
|
||||||
|
Temperature: *a.Ctx.Config.Defaults.Temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
if a.Agent != nil {
|
||||||
|
params.Toolbox = a.Agent.Toolbox
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-stopSignal:
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
msg, err := p.CreateChatCompletionStream(
|
||||||
|
ctx, params, conversation.MessagesToAPI(messages), chatReplyChunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
if msg != nil {
|
||||||
|
msg := conversation.MessageFromAPI(*msg)
|
||||||
|
msg.Metadata.GenerationProvider = &a.ProviderName
|
||||||
|
msg.Metadata.GenerationModel = &a.Model
|
||||||
|
return &msg, err
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
66
pkg/tui/shared/shared.go
Normal file
66
pkg/tui/shared/shared.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package shared
|
||||||
|
|
||||||
|
import (
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
// An analogue to tea.Model with support for checking if the model has been
|
||||||
|
// initialized before
|
||||||
|
type ViewModel interface {
|
||||||
|
Init() tea.Cmd
|
||||||
|
Update(tea.Msg) (ViewModel, tea.Cmd)
|
||||||
|
|
||||||
|
// View methods
|
||||||
|
Header(width int) string
|
||||||
|
// Render the view's main content into a container of the given dimensions
|
||||||
|
Content(width, height int) string
|
||||||
|
Footer(width int) string
|
||||||
|
}
|
||||||
|
|
||||||
|
type View int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ViewChat View = iota
|
||||||
|
ViewConversations
|
||||||
|
ViewSettings
|
||||||
|
//StateHelp
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// send to change the current state
|
||||||
|
MsgViewChange View
|
||||||
|
// sent to a state when it is entered, with the view we're leaving
|
||||||
|
MsgViewEnter View
|
||||||
|
// sent when a recoverable error occurs (displayed to user)
|
||||||
|
MsgError struct { Err error }
|
||||||
|
// sent when the view has handled a key input
|
||||||
|
MsgKeyHandled tea.KeyMsg
|
||||||
|
)
|
||||||
|
|
||||||
|
func ViewEnter(from View) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
return MsgViewEnter(from)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChangeView(view View) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
return MsgViewChange(view)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func KeyHandled(key tea.KeyMsg) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
return MsgKeyHandled(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WrapError(err error) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
return MsgError{ Err: err }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AsMsgError(err error) MsgError {
|
||||||
|
return MsgError{ Err: err }
|
||||||
|
}
|
8
pkg/tui/styles/styles.go
Normal file
8
pkg/tui/styles/styles.go
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
package styles
|
||||||
|
|
||||||
|
import "github.com/charmbracelet/lipgloss"
|
||||||
|
|
||||||
|
var Header = lipgloss.NewStyle().
|
||||||
|
PaddingLeft(1).
|
||||||
|
PaddingRight(1).
|
||||||
|
Background(lipgloss.Color("0"))
|
922
pkg/tui/tui.go
922
pkg/tui/tui.go
@ -1,845 +1,163 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
// The terminal UI for lmcli, launched from the `lmcli chat` command
|
|
||||||
// TODO:
|
|
||||||
// - conversation list view
|
|
||||||
// - change model
|
|
||||||
// - rename conversation
|
|
||||||
// - set system prompt
|
|
||||||
// - system prompt library?
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
"github.com/charmbracelet/bubbles/spinner"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
"github.com/charmbracelet/bubbles/textarea"
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
"github.com/charmbracelet/bubbles/viewport"
|
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/views/settings"
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/muesli/reflow/wordwrap"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type focusState int
|
type Model struct {
|
||||||
|
App *model.AppModel
|
||||||
|
|
||||||
const (
|
// window size
|
||||||
focusInput focusState = iota
|
|
||||||
focusMessages
|
|
||||||
)
|
|
||||||
|
|
||||||
type editorTarget int
|
|
||||||
|
|
||||||
const (
|
|
||||||
input editorTarget = iota
|
|
||||||
selectedMessage
|
|
||||||
)
|
|
||||||
|
|
||||||
type model struct {
|
|
||||||
width int
|
width int
|
||||||
height int
|
height int
|
||||||
|
|
||||||
ctx *lmcli.Context
|
// errors to display
|
||||||
convShortname string
|
// TODO: allow dismissing errors
|
||||||
|
errs []error
|
||||||
|
|
||||||
// application state
|
activeView shared.View
|
||||||
conversation *models.Conversation
|
views map[shared.View]shared.ViewModel
|
||||||
messages []models.Message
|
|
||||||
waitingForReply bool
|
|
||||||
editorTarget editorTarget
|
|
||||||
stopSignal chan interface{}
|
|
||||||
replyChan chan models.Message
|
|
||||||
replyChunkChan chan string
|
|
||||||
persistence bool // whether we will save new messages in the conversation
|
|
||||||
err error
|
|
||||||
|
|
||||||
// ui state
|
|
||||||
focus focusState
|
|
||||||
wrap bool // whether message content is wrapped to viewport width
|
|
||||||
status string // a general status message
|
|
||||||
highlightCache []string // a cache of syntax highlighted message content
|
|
||||||
messageOffsets []int
|
|
||||||
selectedMessage int
|
|
||||||
|
|
||||||
// ui elements
|
|
||||||
content viewport.Model
|
|
||||||
input textarea.Model
|
|
||||||
spinner spinner.Model
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type message struct {
|
func initialModel(ctx *lmcli.Context, opts LaunchOptions) *Model {
|
||||||
role string
|
app := model.NewAppModel(ctx, opts.InitialConversation)
|
||||||
content string
|
|
||||||
}
|
|
||||||
|
|
||||||
// custom tea.Msg types
|
m := Model{
|
||||||
type (
|
App: app,
|
||||||
// sent on each chunk received from LLM
|
activeView: opts.InitialView,
|
||||||
msgResponseChunk string
|
views: map[shared.View]shared.ViewModel{
|
||||||
// sent when response is finished being received
|
shared.ViewChat: chat.Chat(app),
|
||||||
msgResponseEnd string
|
shared.ViewConversations: conversations.Conversations(app),
|
||||||
// a special case of msgError that stops the response waiting animation
|
shared.ViewSettings: settings.Settings(app),
|
||||||
msgResponseError error
|
|
||||||
// sent on each completed reply
|
|
||||||
msgAssistantReply models.Message
|
|
||||||
// sent when a conversation is (re)loaded
|
|
||||||
msgConversationLoaded *models.Conversation
|
|
||||||
// sent when a new conversation title is set
|
|
||||||
msgConversationTitleChanged string
|
|
||||||
// send when a conversation's messages are laoded
|
|
||||||
msgMessagesLoaded []models.Message
|
|
||||||
// sent when an error occurs
|
|
||||||
msgError error
|
|
||||||
)
|
|
||||||
|
|
||||||
// styles
|
|
||||||
var (
|
|
||||||
userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10"))
|
|
||||||
assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12"))
|
|
||||||
messageStyle = lipgloss.NewStyle().PaddingLeft(2).PaddingRight(2)
|
|
||||||
headerStyle = lipgloss.NewStyle().
|
|
||||||
Background(lipgloss.Color("0"))
|
|
||||||
conversationStyle = lipgloss.NewStyle().
|
|
||||||
MarginTop(1).
|
|
||||||
MarginBottom(1)
|
|
||||||
footerStyle = lipgloss.NewStyle().
|
|
||||||
BorderTop(true).
|
|
||||||
BorderStyle(lipgloss.NormalBorder())
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m model) Init() tea.Cmd {
|
|
||||||
return tea.Batch(
|
|
||||||
textarea.Blink,
|
|
||||||
m.spinner.Tick,
|
|
||||||
m.loadConversation(m.convShortname),
|
|
||||||
m.waitForChunk(),
|
|
||||||
m.waitForReply(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wrapError(err error) tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
return msgError(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|
||||||
var cmds []tea.Cmd
|
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
|
||||||
case msgTempfileEditorClosed:
|
|
||||||
contents := string(msg)
|
|
||||||
switch m.editorTarget {
|
|
||||||
case input:
|
|
||||||
m.input.SetValue(contents)
|
|
||||||
case selectedMessage:
|
|
||||||
m.setMessageContents(m.selectedMessage, contents)
|
|
||||||
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
|
|
||||||
// update persisted message
|
|
||||||
err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
|
|
||||||
if err != nil {
|
|
||||||
cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.updateContent()
|
|
||||||
}
|
|
||||||
case tea.KeyMsg:
|
|
||||||
switch msg.String() {
|
|
||||||
case "ctrl+c":
|
|
||||||
if m.waitingForReply {
|
|
||||||
m.stopSignal <- ""
|
|
||||||
} else {
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
case "ctrl+p":
|
|
||||||
m.persistence = !m.persistence
|
|
||||||
case "ctrl+w":
|
|
||||||
m.wrap = !m.wrap
|
|
||||||
m.updateContent()
|
|
||||||
case "q":
|
|
||||||
if m.focus != focusInput {
|
|
||||||
return m, tea.Quit
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
var inputHandled tea.Cmd
|
|
||||||
switch m.focus {
|
|
||||||
case focusInput:
|
|
||||||
inputHandled = m.handleInputKey(msg)
|
|
||||||
case focusMessages:
|
|
||||||
inputHandled = m.handleMessagesKey(msg)
|
|
||||||
}
|
|
||||||
if inputHandled != nil {
|
|
||||||
return m, inputHandled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case tea.WindowSizeMsg:
|
|
||||||
m.width = msg.Width
|
|
||||||
m.height = msg.Height
|
|
||||||
m.content.Width = msg.Width
|
|
||||||
m.content.Height = msg.Height - m.getFixedComponentHeight()
|
|
||||||
m.input.SetWidth(msg.Width - 1)
|
|
||||||
m.updateContent()
|
|
||||||
case msgConversationLoaded:
|
|
||||||
m.conversation = (*models.Conversation)(msg)
|
|
||||||
cmds = append(cmds, m.loadMessages(m.conversation))
|
|
||||||
case msgMessagesLoaded:
|
|
||||||
m.setMessages(msg)
|
|
||||||
m.updateContent()
|
|
||||||
case msgResponseChunk:
|
|
||||||
chunk := string(msg)
|
|
||||||
last := len(m.messages) - 1
|
|
||||||
if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
|
|
||||||
m.setMessageContents(last, m.messages[last].Content+chunk)
|
|
||||||
} else {
|
|
||||||
m.addMessage(models.Message{
|
|
||||||
Role: models.MessageRoleAssistant,
|
|
||||||
Content: chunk,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
m.updateContent()
|
|
||||||
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
|
|
||||||
case msgAssistantReply:
|
|
||||||
// the last reply that was being worked on is finished
|
|
||||||
reply := models.Message(msg)
|
|
||||||
last := len(m.messages) - 1
|
|
||||||
if last < 0 {
|
|
||||||
panic("Unexpected empty messages handling msgReply")
|
|
||||||
}
|
|
||||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
|
||||||
if m.messages[last].Role == models.MessageRoleAssistant {
|
|
||||||
// the last message was an assistant message, so this is a continuation
|
|
||||||
if reply.Role == models.MessageRoleToolCall {
|
|
||||||
// update last message rrole to tool call
|
|
||||||
m.messages[last].Role = models.MessageRoleToolCall
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
m.addMessage(reply)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.persistence {
|
|
||||||
var err error
|
|
||||||
if m.conversation.ID == 0 {
|
|
||||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
cmds = append(cmds, wrapError(err))
|
|
||||||
} else {
|
|
||||||
cmds = append(cmds, m.persistConversation())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.conversation.Title == "" {
|
|
||||||
cmds = append(cmds, m.generateConversationTitle())
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateContent()
|
|
||||||
cmds = append(cmds, m.waitForReply())
|
|
||||||
case msgResponseEnd:
|
|
||||||
m.waitingForReply = false
|
|
||||||
last := len(m.messages) - 1
|
|
||||||
if last < 0 {
|
|
||||||
panic("Unexpected empty messages handling msgResponseEnd")
|
|
||||||
}
|
|
||||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
|
||||||
m.updateContent()
|
|
||||||
m.status = "Press ctrl+s to send"
|
|
||||||
case msgResponseError:
|
|
||||||
m.waitingForReply = false
|
|
||||||
m.status = "Press ctrl+s to send"
|
|
||||||
m.err = error(msg)
|
|
||||||
case msgConversationTitleChanged:
|
|
||||||
title := string(msg)
|
|
||||||
m.conversation.Title = title
|
|
||||||
if m.persistence {
|
|
||||||
err := m.ctx.Store.SaveConversation(m.conversation)
|
|
||||||
if err != nil {
|
|
||||||
cmds = append(cmds, wrapError(err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case msgError:
|
|
||||||
m.err = error(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cmd tea.Cmd
|
|
||||||
m.spinner, cmd = m.spinner.Update(msg)
|
|
||||||
if cmd != nil {
|
|
||||||
cmds = append(cmds, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
inputCaptured := false
|
|
||||||
m.input, cmd = m.input.Update(msg)
|
|
||||||
if cmd != nil {
|
|
||||||
inputCaptured = true
|
|
||||||
cmds = append(cmds, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !inputCaptured {
|
|
||||||
m.content, cmd = m.content.Update(msg)
|
|
||||||
if cmd != nil {
|
|
||||||
cmds = append(cmds, cmd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m, tea.Batch(cmds...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m model) View() string {
|
|
||||||
if m.width == 0 {
|
|
||||||
// this is the case upon initial startup, but it's also a safe bet that
|
|
||||||
// we can just skip rendering if the terminal is really 0 width...
|
|
||||||
// without this, the m.*View() functions may crash
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
sections := make([]string, 0, 6)
|
|
||||||
sections = append(sections, m.headerView())
|
|
||||||
sections = append(sections, m.contentView())
|
|
||||||
error := m.errorView()
|
|
||||||
if error != "" {
|
|
||||||
sections = append(sections, error)
|
|
||||||
}
|
|
||||||
sections = append(sections, m.inputView())
|
|
||||||
sections = append(sections, m.footerView())
|
|
||||||
|
|
||||||
return lipgloss.JoinVertical(
|
|
||||||
lipgloss.Left,
|
|
||||||
sections...,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns the total height of "fixed" components, which are those which don't
|
|
||||||
// change height dependent on window size.
|
|
||||||
func (m *model) getFixedComponentHeight() int {
|
|
||||||
h := 0
|
|
||||||
h += m.input.Height()
|
|
||||||
h += lipgloss.Height(m.headerView())
|
|
||||||
h += lipgloss.Height(m.footerView())
|
|
||||||
errorView := m.errorView()
|
|
||||||
if errorView != "" {
|
|
||||||
h += lipgloss.Height(errorView)
|
|
||||||
}
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) headerView() string {
|
|
||||||
titleStyle := lipgloss.NewStyle().
|
|
||||||
PaddingLeft(1).
|
|
||||||
PaddingRight(1).
|
|
||||||
Bold(true)
|
|
||||||
var title string
|
|
||||||
if m.conversation != nil && m.conversation.Title != "" {
|
|
||||||
title = m.conversation.Title
|
|
||||||
} else {
|
|
||||||
title = "Untitled"
|
|
||||||
}
|
|
||||||
part := titleStyle.Render(title)
|
|
||||||
|
|
||||||
return headerStyle.Width(m.width).Render(part)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) contentView() string {
|
|
||||||
return m.content.View()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) errorView() string {
|
|
||||||
if m.err == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return lipgloss.NewStyle().
|
|
||||||
Width(m.width).
|
|
||||||
AlignHorizontal(lipgloss.Center).
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color("1")).
|
|
||||||
Render(fmt.Sprintf("%s", m.err))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) inputView() string {
|
|
||||||
return m.input.View()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) footerView() string {
|
|
||||||
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
|
|
||||||
segmentSeparator := "|"
|
|
||||||
|
|
||||||
savingStyle := segmentStyle.Copy().Bold(true)
|
|
||||||
saving := ""
|
|
||||||
if m.persistence {
|
|
||||||
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
|
|
||||||
} else {
|
|
||||||
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
|
|
||||||
}
|
|
||||||
|
|
||||||
status := m.status
|
|
||||||
if m.waitingForReply {
|
|
||||||
status += m.spinner.View()
|
|
||||||
}
|
|
||||||
|
|
||||||
leftSegments := []string{
|
|
||||||
saving,
|
|
||||||
segmentStyle.Render(status),
|
|
||||||
}
|
|
||||||
rightSegments := []string{
|
|
||||||
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
|
|
||||||
}
|
|
||||||
|
|
||||||
left := strings.Join(leftSegments, segmentSeparator)
|
|
||||||
right := strings.Join(rightSegments, segmentSeparator)
|
|
||||||
|
|
||||||
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
|
|
||||||
remaining := m.width - totalWidth
|
|
||||||
|
|
||||||
var padding string
|
|
||||||
if remaining > 0 {
|
|
||||||
padding = strings.Repeat(" ", remaining)
|
|
||||||
}
|
|
||||||
|
|
||||||
footer := left + padding + right
|
|
||||||
if remaining < 0 {
|
|
||||||
ellipses := "... "
|
|
||||||
// this doesn't work very well, due to trying to trim a string with
|
|
||||||
// ansii chars already in it
|
|
||||||
footer = footer[:(len(footer)+remaining)-len(ellipses)-3] + ellipses
|
|
||||||
}
|
|
||||||
return footerStyle.Width(m.width).Render(footer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func initialModel(ctx *lmcli.Context, convShortname string) model {
|
|
||||||
m := model{
|
|
||||||
ctx: ctx,
|
|
||||||
convShortname: convShortname,
|
|
||||||
conversation: &models.Conversation{},
|
|
||||||
persistence: true,
|
|
||||||
|
|
||||||
stopSignal: make(chan interface{}),
|
|
||||||
replyChan: make(chan models.Message),
|
|
||||||
replyChunkChan: make(chan string),
|
|
||||||
|
|
||||||
wrap: true,
|
|
||||||
selectedMessage: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
m.content = viewport.New(0, 0)
|
|
||||||
|
|
||||||
m.input = textarea.New()
|
|
||||||
m.input.CharLimit = 0
|
|
||||||
m.input.Placeholder = "Enter a message"
|
|
||||||
|
|
||||||
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
|
||||||
m.input.ShowLineNumbers = false
|
|
||||||
m.input.SetHeight(4)
|
|
||||||
m.input.Focus()
|
|
||||||
|
|
||||||
m.spinner = spinner.New(spinner.WithSpinner(
|
|
||||||
spinner.Spinner{
|
|
||||||
Frames: []string{
|
|
||||||
". ",
|
|
||||||
".. ",
|
|
||||||
"...",
|
|
||||||
".. ",
|
|
||||||
". ",
|
|
||||||
" ",
|
|
||||||
},
|
|
||||||
FPS: time.Second / 3,
|
|
||||||
},
|
},
|
||||||
))
|
}
|
||||||
|
|
||||||
m.waitingForReply = false
|
return &m
|
||||||
m.status = "Press ctrl+s to send"
|
|
||||||
return m
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fraction is the fraction of the total screen height into view the offset
|
func (m *Model) Init() tea.Cmd {
|
||||||
// should be scrolled into view. 0.5 = items will be snapped to middle of
|
var cmds []tea.Cmd
|
||||||
// view
|
for _, v := range m.views {
|
||||||
func scrollIntoView(vp *viewport.Model, offset int, fraction float32) {
|
// Init views
|
||||||
currentOffset := vp.YOffset
|
cmds = append(cmds, v.Init())
|
||||||
if offset >= currentOffset && offset < currentOffset+vp.Height {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
distance := currentOffset - offset
|
|
||||||
if distance < 0 {
|
|
||||||
// we should scroll down until it just comes into view
|
|
||||||
vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1)
|
|
||||||
} else {
|
|
||||||
// we should scroll up
|
|
||||||
vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction))
|
|
||||||
}
|
}
|
||||||
|
cmds = append(cmds, func() tea.Msg {
|
||||||
|
// Initial view change
|
||||||
|
return shared.MsgViewChange(m.activeView)
|
||||||
|
})
|
||||||
|
return tea.Batch(cmds...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
|
func (m *Model) handleGlobalInput(msg tea.KeyMsg) tea.Cmd {
|
||||||
switch msg.String() {
|
view, cmd := m.views[m.activeView].Update(msg)
|
||||||
case "tab":
|
m.views[m.activeView] = view
|
||||||
m.focus = focusInput
|
if cmd != nil {
|
||||||
m.updateContent()
|
|
||||||
m.input.Focus()
|
|
||||||
case "e":
|
|
||||||
message := m.messages[m.selectedMessage]
|
|
||||||
cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
|
|
||||||
m.editorTarget = selectedMessage
|
|
||||||
return cmd
|
|
||||||
case "ctrl+k":
|
|
||||||
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
|
|
||||||
m.selectedMessage--
|
|
||||||
m.updateContent()
|
|
||||||
offset := m.messageOffsets[m.selectedMessage]
|
|
||||||
scrollIntoView(&m.content, offset, 0.1)
|
|
||||||
}
|
|
||||||
case "ctrl+j":
|
|
||||||
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
|
|
||||||
m.selectedMessage++
|
|
||||||
m.updateContent()
|
|
||||||
offset := m.messageOffsets[m.selectedMessage]
|
|
||||||
scrollIntoView(&m.content, offset, 0.1)
|
|
||||||
}
|
|
||||||
case "ctrl+r":
|
|
||||||
// resubmit the conversation with all messages up until and including
|
|
||||||
// the selected message
|
|
||||||
if len(m.messages) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.messages = m.messages[:m.selectedMessage+1]
|
|
||||||
m.highlightCache = m.highlightCache[:m.selectedMessage+1]
|
|
||||||
m.updateContent()
|
|
||||||
m.content.GotoBottom()
|
|
||||||
return m.promptLLM()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
|
|
||||||
switch msg.String() {
|
|
||||||
case "esc":
|
|
||||||
m.focus = focusMessages
|
|
||||||
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
|
|
||||||
m.selectedMessage = len(m.messages) - 1
|
|
||||||
}
|
|
||||||
m.updateContent()
|
|
||||||
m.input.Blur()
|
|
||||||
case "ctrl+s":
|
|
||||||
userInput := strings.TrimSpace(m.input.Value())
|
|
||||||
if strings.TrimSpace(userInput) == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
|
|
||||||
return wrapError(fmt.Errorf("Can't reply to a user message"))
|
|
||||||
}
|
|
||||||
|
|
||||||
reply := models.Message{
|
|
||||||
Role: models.MessageRoleUser,
|
|
||||||
Content: userInput,
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.persistence {
|
|
||||||
var err error
|
|
||||||
if m.conversation.ID == 0 {
|
|
||||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensure all messages up to the one we're about to add are
|
|
||||||
// persistent
|
|
||||||
cmd := m.persistConversation()
|
|
||||||
if cmd != nil {
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
// persist our new message, returning with any possible errors
|
|
||||||
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
reply = *savedReply
|
|
||||||
}
|
|
||||||
|
|
||||||
m.input.SetValue("")
|
|
||||||
m.addMessage(reply)
|
|
||||||
|
|
||||||
m.updateContent()
|
|
||||||
m.content.GotoBottom()
|
|
||||||
return m.promptLLM()
|
|
||||||
case "ctrl+e":
|
|
||||||
cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
|
|
||||||
m.editorTarget = input
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) loadConversation(shortname string) tea.Cmd {
|
switch msg.String() {
|
||||||
return func() tea.Msg {
|
case "ctrl+c", "ctrl+q":
|
||||||
if shortname == "" {
|
return tea.Quit
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c, err := m.ctx.Store.ConversationByShortName(shortname)
|
|
||||||
if err != nil {
|
|
||||||
return msgError(fmt.Errorf("Could not lookup conversation: %v", err))
|
|
||||||
}
|
|
||||||
if c.ID == 0 {
|
|
||||||
return msgError(fmt.Errorf("Conversation not found: %s", shortname))
|
|
||||||
}
|
|
||||||
return msgConversationLoaded(c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) loadMessages(c *models.Conversation) tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
messages, err := m.ctx.Store.Messages(c)
|
|
||||||
if err != nil {
|
|
||||||
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
|
||||||
}
|
|
||||||
return msgMessagesLoaded(messages)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) waitForReply() tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
return msgAssistantReply(<-m.replyChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) waitForChunk() tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
return msgResponseChunk(<-m.replyChunkChan)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) generateConversationTitle() tea.Cmd {
|
|
||||||
return func() tea.Msg {
|
|
||||||
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
|
|
||||||
if err != nil {
|
|
||||||
return msgError(err)
|
|
||||||
}
|
|
||||||
return msgConversationTitleChanged(title)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) promptLLM() tea.Cmd {
|
|
||||||
m.waitingForReply = true
|
|
||||||
m.status = "Press ctrl+c to cancel"
|
|
||||||
|
|
||||||
return func() tea.Msg {
|
|
||||||
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
|
|
||||||
if err != nil {
|
|
||||||
return msgError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestParams := models.RequestParameters{
|
|
||||||
Model: *m.ctx.Config.Defaults.Model,
|
|
||||||
MaxTokens: *m.ctx.Config.Defaults.MaxTokens,
|
|
||||||
Temperature: *m.ctx.Config.Defaults.Temperature,
|
|
||||||
ToolBag: m.ctx.EnabledTools,
|
|
||||||
}
|
|
||||||
|
|
||||||
replyHandler := func(msg models.Message) {
|
|
||||||
m.replyChan <- msg
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
canceled := false
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-m.stopSignal:
|
|
||||||
canceled = true
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
resp, err := completionProvider.CreateChatCompletionStream(
|
|
||||||
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil && !canceled {
|
|
||||||
return msgResponseError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return msgResponseEnd(resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) persistConversation() tea.Cmd {
|
|
||||||
existingMessages, err := m.ctx.Store.Messages(m.conversation)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
existingById := make(map[uint]*models.Message, len(existingMessages))
|
|
||||||
for _, msg := range existingMessages {
|
|
||||||
existingById[msg.ID] = &msg
|
|
||||||
}
|
|
||||||
|
|
||||||
currentById := make(map[uint]*models.Message, len(m.messages))
|
|
||||||
for _, msg := range m.messages {
|
|
||||||
currentById[msg.ID] = &msg
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, msg := range existingMessages {
|
|
||||||
_, ok := currentById[msg.ID]
|
|
||||||
if !ok {
|
|
||||||
err := m.ctx.Store.DeleteMessage(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, msg := range m.messages {
|
|
||||||
if msg.ID > 0 {
|
|
||||||
exist, ok := existingById[msg.ID]
|
|
||||||
if ok {
|
|
||||||
if msg.Content == exist.Content {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// update message when contents don't match that of store
|
|
||||||
err := m.ctx.Store.UpdateMessage(&msg)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// this would be quite odd... and I'm not sure how to handle
|
|
||||||
// it at the time of writing this
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
m.setMessage(i, *newMessage)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) setMessages(messages []models.Message) {
|
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
m.messages = messages
|
switch msg := msg.(type) {
|
||||||
m.highlightCache = make([]string, len(messages))
|
case tea.WindowSizeMsg:
|
||||||
for i, msg := range m.messages {
|
m.width, m.height = msg.Width, msg.Height
|
||||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
case tea.KeyMsg:
|
||||||
m.highlightCache[i] = highlighted
|
cmd := m.handleGlobalInput(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
case shared.MsgViewChange:
|
||||||
|
currView := m.activeView
|
||||||
|
m.activeView = shared.View(msg)
|
||||||
|
return m, tea.Batch(tea.WindowSize(), shared.ViewEnter(currView))
|
||||||
|
case shared.MsgError:
|
||||||
|
m.errs = append(m.errs, msg.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
view, cmd := m.views[m.activeView].Update(msg)
|
||||||
|
m.views[m.activeView] = view
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) View() string {
|
||||||
|
if m.width == 0 || m.height == 0 {
|
||||||
|
// we're dimensionless!
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
header := m.views[m.activeView].Header(m.width)
|
||||||
|
footer := m.views[m.activeView].Footer(m.width)
|
||||||
|
fixedUIHeight := tuiutil.Height(header) + tuiutil.Height(footer)
|
||||||
|
|
||||||
|
errBanners := make([]string, len(m.errs))
|
||||||
|
for idx, err := range m.errs {
|
||||||
|
errBanners[idx] = tuiutil.ErrorBanner(err, m.width)
|
||||||
|
fixedUIHeight += tuiutil.Height(errBanners[idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
content := m.views[m.activeView].Content(m.width, m.height-fixedUIHeight)
|
||||||
|
|
||||||
|
sections := make([]string, 0, 4)
|
||||||
|
if header != "" {
|
||||||
|
sections = append(sections, header)
|
||||||
|
}
|
||||||
|
if content != "" {
|
||||||
|
sections = append(sections, content)
|
||||||
|
}
|
||||||
|
if footer != "" {
|
||||||
|
sections = append(sections, footer)
|
||||||
|
}
|
||||||
|
for _, errBanner := range errBanners {
|
||||||
|
sections = append(sections, errBanner)
|
||||||
|
}
|
||||||
|
return lipgloss.JoinVertical(lipgloss.Left, sections...)
|
||||||
|
}
|
||||||
|
|
||||||
|
type LaunchOptions struct {
|
||||||
|
InitialConversation *conversation.Conversation
|
||||||
|
InitialView shared.View
|
||||||
|
}
|
||||||
|
|
||||||
|
type LaunchOption func(*LaunchOptions)
|
||||||
|
|
||||||
|
func WithInitialConversation(conv *conversation.Conversation) LaunchOption {
|
||||||
|
return func(opts *LaunchOptions) {
|
||||||
|
opts.InitialConversation = conv
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *model) setMessage(i int, msg models.Message) {
|
func WithInitialView(view shared.View) LaunchOption {
|
||||||
if i >= len(m.messages) {
|
return func(opts *LaunchOptions) {
|
||||||
panic("i out of range")
|
opts.InitialView = view
|
||||||
}
|
|
||||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
|
||||||
m.messages[i] = msg
|
|
||||||
m.highlightCache[i] = highlighted
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) addMessage(msg models.Message) {
|
|
||||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
|
||||||
m.messages = append(m.messages, msg)
|
|
||||||
m.highlightCache = append(m.highlightCache, highlighted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) setMessageContents(i int, content string) {
|
|
||||||
if i >= len(m.messages) {
|
|
||||||
panic("i out of range")
|
|
||||||
}
|
|
||||||
highlighted, _ := m.ctx.Chroma.HighlightS(content)
|
|
||||||
m.messages[i].Content = content
|
|
||||||
m.highlightCache[i] = highlighted
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) updateContent() {
|
|
||||||
atBottom := m.content.AtBottom()
|
|
||||||
m.content.SetContent(m.conversationView())
|
|
||||||
if atBottom {
|
|
||||||
// if we were at bottom before the update, scroll with the output
|
|
||||||
m.content.GotoBottom()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// render the conversation into a string
|
func Launch(ctx *lmcli.Context, options ...LaunchOption) error {
|
||||||
func (m *model) conversationView() string {
|
opts := &LaunchOptions{
|
||||||
sb := strings.Builder{}
|
InitialView: shared.ViewChat,
|
||||||
msgCnt := len(m.messages)
|
}
|
||||||
|
for _, opt := range options {
|
||||||
m.messageOffsets = make([]int, len(m.messages))
|
opt(opts)
|
||||||
lineCnt := conversationStyle.GetMarginTop()
|
|
||||||
for i, message := range m.messages {
|
|
||||||
m.messageOffsets[i] = lineCnt
|
|
||||||
|
|
||||||
icon := "⚙️"
|
|
||||||
friendly := message.Role.FriendlyRole()
|
|
||||||
style := lipgloss.NewStyle().Bold(true).Faint(true)
|
|
||||||
|
|
||||||
switch message.Role {
|
|
||||||
case models.MessageRoleUser:
|
|
||||||
icon = ""
|
|
||||||
style = userStyle
|
|
||||||
case models.MessageRoleAssistant:
|
|
||||||
icon = ""
|
|
||||||
style = assistantStyle
|
|
||||||
case models.MessageRoleToolCall, models.MessageRoleToolResult:
|
|
||||||
icon = "🔧"
|
|
||||||
}
|
|
||||||
|
|
||||||
// write message heading with space for content
|
|
||||||
user := style.Render(icon + friendly)
|
|
||||||
|
|
||||||
var prefix string
|
|
||||||
var suffix string
|
|
||||||
|
|
||||||
faint := lipgloss.NewStyle().Faint(true)
|
|
||||||
if m.focus == focusMessages {
|
|
||||||
if i == m.selectedMessage {
|
|
||||||
prefix = "> "
|
|
||||||
}
|
|
||||||
suffix += faint.Render(fmt.Sprintf(" (%d/%d)", i+1, msgCnt))
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.ID == 0 {
|
|
||||||
suffix += faint.Render(" (not saved)")
|
|
||||||
}
|
|
||||||
|
|
||||||
header := lipgloss.NewStyle().PaddingLeft(1).Render(prefix + user + suffix)
|
|
||||||
sb.WriteString(header)
|
|
||||||
lineCnt += lipgloss.Height(header)
|
|
||||||
|
|
||||||
// TODO: special rendering for tool calls/results?
|
|
||||||
if message.Content != "" {
|
|
||||||
sb.WriteString("\n\n")
|
|
||||||
lineCnt += 1
|
|
||||||
|
|
||||||
// write message contents
|
|
||||||
var highlighted string
|
|
||||||
if m.highlightCache[i] == "" {
|
|
||||||
highlighted = message.Content
|
|
||||||
} else {
|
|
||||||
highlighted = m.highlightCache[i]
|
|
||||||
}
|
|
||||||
var contents string
|
|
||||||
if m.wrap {
|
|
||||||
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 2
|
|
||||||
wrapped := wordwrap.String(highlighted, wrapWidth)
|
|
||||||
contents = wrapped
|
|
||||||
} else {
|
|
||||||
contents = highlighted
|
|
||||||
}
|
|
||||||
sb.WriteString(messageStyle.Width(0).Render(contents))
|
|
||||||
lineCnt += lipgloss.Height(contents)
|
|
||||||
}
|
|
||||||
|
|
||||||
if i < msgCnt-1 {
|
|
||||||
sb.WriteString("\n\n")
|
|
||||||
lineCnt += 1
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return conversationStyle.Render(sb.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func Launch(ctx *lmcli.Context, convShortname string) error {
|
program := tea.NewProgram(initialModel(ctx, *opts), tea.WithAltScreen())
|
||||||
p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen())
|
if _, err := program.Run(); err != nil {
|
||||||
if _, err := p.Run(); err != nil {
|
|
||||||
return fmt.Errorf("Error running program: %v", err)
|
return fmt.Errorf("Error running program: %v", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,42 +0,0 @@
|
|||||||
package tui
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
)
|
|
||||||
|
|
||||||
type msgTempfileEditorClosed string
|
|
||||||
|
|
||||||
// openTempfileEditor opens an $EDITOR on a new temporary file with the given
|
|
||||||
// content. Upon closing, the contents of the file are read back returned
|
|
||||||
// wrapped in a msgTempfileEditorClosed returned by the tea.Cmd
|
|
||||||
func openTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
|
|
||||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
|
||||||
|
|
||||||
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
|
||||||
if err != nil {
|
|
||||||
return wrapError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
editor := os.Getenv("EDITOR")
|
|
||||||
if editor == "" {
|
|
||||||
editor = "vim"
|
|
||||||
}
|
|
||||||
|
|
||||||
c := exec.Command(editor, msgFile.Name())
|
|
||||||
return tea.ExecProcess(c, func(err error) tea.Msg {
|
|
||||||
bytes, err := os.ReadFile(msgFile.Name())
|
|
||||||
if err != nil {
|
|
||||||
return msgError(err)
|
|
||||||
}
|
|
||||||
fileContents := string(bytes)
|
|
||||||
if strings.HasPrefix(fileContents, placeholder) {
|
|
||||||
fileContents = fileContents[len(placeholder):]
|
|
||||||
}
|
|
||||||
stripped := strings.Trim(fileContents, "\n \t")
|
|
||||||
return msgTempfileEditorClosed(stripped)
|
|
||||||
})
|
|
||||||
}
|
|
137
pkg/tui/util/util.go
Normal file
137
pkg/tui/util/util.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
"github.com/muesli/reflow/ansi"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MsgTempfileEditorClosed string
|
||||||
|
|
||||||
|
// OpenTempfileEditor opens $EDITOR on a temporary file with the given content.
|
||||||
|
// Upon closing, the contents of the file are read and returned wrapped in a
|
||||||
|
// MsgTempfileEditorClosed
|
||||||
|
func OpenTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
|
||||||
|
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||||
|
|
||||||
|
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
||||||
|
if err != nil {
|
||||||
|
return func() tea.Msg { return err }
|
||||||
|
}
|
||||||
|
|
||||||
|
editor := os.Getenv("EDITOR")
|
||||||
|
if editor == "" {
|
||||||
|
editor = "vim"
|
||||||
|
}
|
||||||
|
|
||||||
|
c := exec.Command(editor, msgFile.Name())
|
||||||
|
return tea.ExecProcess(c, func(err error) tea.Msg {
|
||||||
|
bytes, err := os.ReadFile(msgFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
os.Remove(msgFile.Name())
|
||||||
|
fileContents := string(bytes)
|
||||||
|
if strings.HasPrefix(fileContents, placeholder) {
|
||||||
|
fileContents = fileContents[len(placeholder):]
|
||||||
|
}
|
||||||
|
stripped := strings.Trim(fileContents, "\n \t")
|
||||||
|
return MsgTempfileEditorClosed(stripped)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// similar to lipgloss.Height, except returns 0 instead of 1 on empty strings
|
||||||
|
func Height(str string) int {
|
||||||
|
if str == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return strings.Count(str, "\n") + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func Width(str string) int {
|
||||||
|
if str == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return ansi.PrintableRuneWidth(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TruncateRightToCellWidth(str string, width int, tail string) string {
|
||||||
|
cellWidth := ansi.PrintableRuneWidth(str)
|
||||||
|
if cellWidth <= width {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
tailWidth := ansi.PrintableRuneWidth(tail)
|
||||||
|
if width <= tailWidth {
|
||||||
|
return tail[:width]
|
||||||
|
}
|
||||||
|
|
||||||
|
targetWidth := width - tailWidth
|
||||||
|
runes := []rune(str)
|
||||||
|
|
||||||
|
for i := len(runes) - 1; i >= 0; i-- {
|
||||||
|
str = string(runes[:i])
|
||||||
|
if ansi.PrintableRuneWidth(str) <= targetWidth {
|
||||||
|
return str + tail
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tail
|
||||||
|
}
|
||||||
|
|
||||||
|
func TruncateLeftToCellWidth(str string, width int, tail string) string {
|
||||||
|
cellWidth := ansi.PrintableRuneWidth(str)
|
||||||
|
if cellWidth <= width {
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
tailWidth := ansi.PrintableRuneWidth(tail)
|
||||||
|
if width <= tailWidth {
|
||||||
|
return tail[:width]
|
||||||
|
}
|
||||||
|
|
||||||
|
targetWidth := width - tailWidth
|
||||||
|
runes := []rune(str)
|
||||||
|
|
||||||
|
for i := 0; i < len(runes); i++ {
|
||||||
|
str = string(runes[i:])
|
||||||
|
if ansi.PrintableRuneWidth(str) <= targetWidth {
|
||||||
|
return tail + str
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tail
|
||||||
|
}
|
||||||
|
|
||||||
|
func ScrollIntoView(vp *viewport.Model, offset int, edge int) {
|
||||||
|
currentOffset := vp.YOffset
|
||||||
|
if offset >= currentOffset && offset < currentOffset+vp.Height {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
distance := currentOffset - offset
|
||||||
|
if distance < 0 {
|
||||||
|
// we should scroll down until it just comes into view
|
||||||
|
vp.SetYOffset(currentOffset - (distance + (vp.Height - edge)) + 1)
|
||||||
|
} else {
|
||||||
|
// we should scroll up
|
||||||
|
vp.SetYOffset(currentOffset - distance - edge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrorBanner(err error, width int) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return lipgloss.NewStyle().
|
||||||
|
Width(width).
|
||||||
|
AlignHorizontal(lipgloss.Center).
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("1")).
|
||||||
|
Render(fmt.Sprintf("%s", err))
|
||||||
|
}
|
168
pkg/tui/views/chat/chat.go
Normal file
168
pkg/tui/views/chat/chat.go
Normal file
@ -0,0 +1,168 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/provider"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
|
"github.com/charmbracelet/bubbles/spinner"
|
||||||
|
"github.com/charmbracelet/bubbles/textarea"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
// custom tea.Msg types
|
||||||
|
type (
|
||||||
|
// sent when a new conversation title generated
|
||||||
|
msgConversationTitleGenerated string
|
||||||
|
// sent when the conversation has been persisted, triggers a reload of contents
|
||||||
|
msgConversationPersisted conversation.Conversation
|
||||||
|
msgMessagesPersisted []conversation.Message
|
||||||
|
// sent when a conversation's messages are laoded
|
||||||
|
msgConversationMessagesLoaded struct {
|
||||||
|
messages []conversation.Message
|
||||||
|
}
|
||||||
|
// a special case of common.MsgError that stops the response waiting animation
|
||||||
|
msgChatResponseError struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
// sent on each chunk received from LLM
|
||||||
|
msgChatResponseChunk provider.Chunk
|
||||||
|
// sent on each completed reply
|
||||||
|
msgChatResponse conversation.Message
|
||||||
|
// sent when the response is canceled
|
||||||
|
msgChatResponseCanceled struct{}
|
||||||
|
// sent when results from a tool call are returned
|
||||||
|
msgToolResults []api.ToolResult
|
||||||
|
// sent when the given message is made the new selected reply of its parent
|
||||||
|
msgSelectedReplyCycled *conversation.Message
|
||||||
|
// sent when the given message is made the new selected root of the current conversation
|
||||||
|
msgSelectedRootCycled *conversation.Message
|
||||||
|
// sent when a message's contents are updated and saved
|
||||||
|
msgMessageUpdated *conversation.Message
|
||||||
|
// sent when a message is cloned, with the cloned message
|
||||||
|
msgMessageCloned *conversation.Message
|
||||||
|
)
|
||||||
|
|
||||||
|
type focusState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
focusInput focusState = iota
|
||||||
|
focusMessages
|
||||||
|
)
|
||||||
|
|
||||||
|
type editorTarget int
|
||||||
|
|
||||||
|
const (
|
||||||
|
input editorTarget = iota
|
||||||
|
selectedMessage
|
||||||
|
)
|
||||||
|
|
||||||
|
type state int
|
||||||
|
|
||||||
|
const (
|
||||||
|
idle state = iota
|
||||||
|
loading
|
||||||
|
pendingResponse
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
// App state
|
||||||
|
App *model.AppModel
|
||||||
|
Height int
|
||||||
|
Width int
|
||||||
|
|
||||||
|
// Chat view state
|
||||||
|
state state // current overall status of the view
|
||||||
|
selectedMessage int
|
||||||
|
editorTarget editorTarget
|
||||||
|
stopSignal chan struct{}
|
||||||
|
replyChan chan conversation.Message
|
||||||
|
chatReplyChunks chan provider.Chunk
|
||||||
|
persistence bool // whether we will save new messages in the conversation
|
||||||
|
|
||||||
|
// UI state
|
||||||
|
focus focusState
|
||||||
|
showDetails bool // whether various details are shown in the UI (e.g. system prompt, tool calls/results, message metadata)
|
||||||
|
wrap bool // whether message content is wrapped to viewport width
|
||||||
|
messageCache []string // cache of syntax highlighted and wrapped message content
|
||||||
|
messageOffsets []int
|
||||||
|
|
||||||
|
// ui elements
|
||||||
|
content viewport.Model
|
||||||
|
input textarea.Model
|
||||||
|
spinner spinner.Model
|
||||||
|
replyCursor cursor.Model // cursor to indicate incoming response
|
||||||
|
|
||||||
|
// metrics
|
||||||
|
tokenCount uint
|
||||||
|
startTime time.Time
|
||||||
|
elapsed time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSpinner() spinner.Model {
|
||||||
|
return spinner.New(spinner.WithSpinner(
|
||||||
|
spinner.Spinner{
|
||||||
|
Frames: []string{
|
||||||
|
"∙∙∙",
|
||||||
|
"●∙∙",
|
||||||
|
"●●∙",
|
||||||
|
"●●●",
|
||||||
|
"∙●●",
|
||||||
|
"∙∙●",
|
||||||
|
"∙∙∙",
|
||||||
|
"∙∙●",
|
||||||
|
"∙●●",
|
||||||
|
"●●●",
|
||||||
|
"●●∙",
|
||||||
|
"●∙∙",
|
||||||
|
},
|
||||||
|
FPS: 440 * time.Millisecond,
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func Chat(app *model.AppModel) *Model {
|
||||||
|
m := Model{
|
||||||
|
App: app,
|
||||||
|
|
||||||
|
state: idle,
|
||||||
|
persistence: true,
|
||||||
|
|
||||||
|
stopSignal: make(chan struct{}),
|
||||||
|
replyChan: make(chan conversation.Message),
|
||||||
|
chatReplyChunks: make(chan provider.Chunk),
|
||||||
|
|
||||||
|
wrap: true,
|
||||||
|
selectedMessage: -1,
|
||||||
|
|
||||||
|
content: viewport.New(0, 0),
|
||||||
|
input: textarea.New(),
|
||||||
|
spinner: getSpinner(),
|
||||||
|
replyCursor: cursor.New(),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.replyCursor.SetChar(" ")
|
||||||
|
m.replyCursor.Focus()
|
||||||
|
|
||||||
|
m.input.Focus()
|
||||||
|
m.input.MaxHeight = 0
|
||||||
|
m.input.CharLimit = 0
|
||||||
|
m.input.ShowLineNumbers = false
|
||||||
|
m.input.Placeholder = "Enter a message"
|
||||||
|
|
||||||
|
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
||||||
|
m.input.FocusedStyle.Base = inputFocusedStyle
|
||||||
|
m.input.BlurredStyle.Base = inputBlurredStyle
|
||||||
|
return &m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Init() tea.Cmd {
|
||||||
|
return tea.Batch(
|
||||||
|
m.waitForResponseChunk(),
|
||||||
|
)
|
||||||
|
}
|
136
pkg/tui/views/chat/cmds.go
Normal file
136
pkg/tui/views/chat/cmds.go
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *Model) waitForResponseChunk() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
return msgChatResponseChunk(<-m.chatReplyChunks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) loadConversationMessages() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
messages, err := m.App.LoadConversationMessages()
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(err)
|
||||||
|
}
|
||||||
|
return msgConversationMessagesLoaded{messages}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) generateConversationTitle() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
title, err := m.App.GenerateConversationTitle(m.App.Messages)
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(err)
|
||||||
|
}
|
||||||
|
return msgConversationTitleGenerated(title)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) cloneMessage(message conversation.Message, selected bool) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
msg, err := m.App.CloneMessage(message, selected)
|
||||||
|
if err != nil {
|
||||||
|
return shared.WrapError(err)
|
||||||
|
}
|
||||||
|
return msgMessageCloned(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) updateMessageContent(message *conversation.Message) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
err := m.App.UpdateMessageContent(message)
|
||||||
|
if err != nil {
|
||||||
|
return shared.WrapError(err)
|
||||||
|
}
|
||||||
|
return msgMessageUpdated(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) cycleSelectedRoot(conv *conversation.Conversation, dir model.MessageCycleDirection) tea.Cmd {
|
||||||
|
if len(conv.RootMessages) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() tea.Msg {
|
||||||
|
nextRoot, err := m.App.CycleSelectedRoot(conv, dir)
|
||||||
|
if err != nil {
|
||||||
|
return shared.WrapError(err)
|
||||||
|
}
|
||||||
|
return msgSelectedRootCycled(nextRoot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) cycleSelectedReply(message *conversation.Message, dir model.MessageCycleDirection) tea.Cmd {
|
||||||
|
if len(message.Replies) < 2 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return func() tea.Msg {
|
||||||
|
nextReply, err := m.App.CycleSelectedReply(message, dir)
|
||||||
|
if err != nil {
|
||||||
|
return shared.WrapError(err)
|
||||||
|
}
|
||||||
|
return msgSelectedReplyCycled(nextReply)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) persistConversation() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
conversation, err := m.App.PersistConversation()
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(err)
|
||||||
|
}
|
||||||
|
return msgConversationPersisted(conversation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) persistMessages() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
messages, err := m.App.PersistMessages()
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(err)
|
||||||
|
}
|
||||||
|
return msgMessagesPersisted(messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
results, err := m.App.ExecuteToolCalls(toolCalls)
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(err)
|
||||||
|
}
|
||||||
|
return msgToolResults(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) promptLLM() tea.Cmd {
|
||||||
|
m.state = pendingResponse
|
||||||
|
m.spinner = getSpinner()
|
||||||
|
m.replyCursor.Blink = false
|
||||||
|
|
||||||
|
m.startTime = time.Now()
|
||||||
|
m.elapsed = 0
|
||||||
|
m.tokenCount = 0
|
||||||
|
|
||||||
|
return tea.Batch(
|
||||||
|
m.spinner.Tick,
|
||||||
|
func() tea.Msg {
|
||||||
|
resp, err := m.App.Prompt(m.App.Messages, m.chatReplyChunks, m.stopSignal)
|
||||||
|
if err != nil {
|
||||||
|
return msgChatResponseError{Err: err}
|
||||||
|
}
|
||||||
|
return msgChatResponse(*resp)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
201
pkg/tui/views/chat/input.go
Normal file
201
pkg/tui/views/chat/input.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
|
||||||
|
switch m.focus {
|
||||||
|
case focusInput:
|
||||||
|
cmd := m.handleInputKey(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
case focusMessages:
|
||||||
|
cmd := m.handleMessagesKey(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "esc":
|
||||||
|
if m.state == pendingResponse {
|
||||||
|
m.stopSignal <- struct{}{}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
return func() tea.Msg {
|
||||||
|
return shared.MsgViewChange(shared.ViewConversations)
|
||||||
|
}
|
||||||
|
case "ctrl+c":
|
||||||
|
if m.state == pendingResponse {
|
||||||
|
m.stopSignal <- struct{}{}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
case "ctrl+g":
|
||||||
|
if m.state == pendingResponse {
|
||||||
|
m.stopSignal <- struct{}{}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
return func() tea.Msg {
|
||||||
|
return shared.MsgViewChange(shared.ViewSettings)
|
||||||
|
}
|
||||||
|
case "ctrl+p":
|
||||||
|
m.persistence = !m.persistence
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+t":
|
||||||
|
m.showDetails = !m.showDetails
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+w":
|
||||||
|
m.wrap = !m.wrap
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+n":
|
||||||
|
m.App.NewConversation()
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) scrollSelection(dir int) {
|
||||||
|
if m.selectedMessage+dir < 0 || m.selectedMessage+dir >= len(m.App.Messages) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newIdx := m.selectedMessage
|
||||||
|
for i := newIdx + dir; i >= 0 && i < len(m.App.Messages); i += dir {
|
||||||
|
if !m.showDetails && m.App.Messages[i].Role.IsSystem() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newIdx = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if newIdx != m.selectedMessage {
|
||||||
|
m.selectedMessage = newIdx
|
||||||
|
m.updateContent()
|
||||||
|
}
|
||||||
|
yOffset := m.messageOffsets[m.selectedMessage]
|
||||||
|
tuiutil.ScrollIntoView(&m.content, yOffset, m.content.Height/2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMessagesKey handles input when the messages pane is focused
|
||||||
|
func (m *Model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
|
||||||
|
switch msg.String() {
|
||||||
|
case "tab", "enter":
|
||||||
|
m.focus = focusInput
|
||||||
|
m.updateContent()
|
||||||
|
m.input.Focus()
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "e":
|
||||||
|
if m.selectedMessage < len(m.App.Messages) {
|
||||||
|
m.editorTarget = selectedMessage
|
||||||
|
return tuiutil.OpenTempfileEditor(
|
||||||
|
"message.*.md",
|
||||||
|
m.App.Messages[m.selectedMessage].Content,
|
||||||
|
"# Edit the message below\n",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case "ctrl+k", "ctrl+up":
|
||||||
|
if m.selectedMessage > 0 {
|
||||||
|
m.scrollSelection(-1)
|
||||||
|
}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+j", "ctrl+down":
|
||||||
|
if m.selectedMessage < len(m.App.Messages)-1 {
|
||||||
|
m.scrollSelection(1)
|
||||||
|
}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+h", "ctrl+left", "ctrl+l", "ctrl+right":
|
||||||
|
dir := model.CyclePrev
|
||||||
|
if msg.String() == "ctrl+l" || msg.String() == "ctrl+right" {
|
||||||
|
dir = model.CycleNext
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd tea.Cmd
|
||||||
|
if m.selectedMessage == 0 {
|
||||||
|
cmd = m.cycleSelectedRoot(&m.App.Conversation, dir)
|
||||||
|
} else if m.selectedMessage > 0 {
|
||||||
|
cmd = m.cycleSelectedReply(&m.App.Messages[m.selectedMessage-1], dir)
|
||||||
|
}
|
||||||
|
return cmd
|
||||||
|
case "ctrl+r":
|
||||||
|
// prompt the model with all messages up to and including the selected message
|
||||||
|
if m.state == idle && m.selectedMessage < len(m.App.Messages) {
|
||||||
|
m.App.Messages = m.App.Messages[:m.selectedMessage+1]
|
||||||
|
m.messageCache = m.messageCache[:m.selectedMessage+1]
|
||||||
|
cmd := m.promptLLM()
|
||||||
|
m.updateContent()
|
||||||
|
m.content.GotoBottom()
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleInputKey handles input when the input textarea is focused
|
||||||
|
func (m *Model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
|
||||||
|
switch msg.String() {
|
||||||
|
case "esc":
|
||||||
|
m.focus = focusMessages
|
||||||
|
if len(m.App.Messages) > 0 {
|
||||||
|
if m.selectedMessage < 0 || m.selectedMessage >= len(m.App.Messages) {
|
||||||
|
m.selectedMessage = len(m.App.Messages) - 1
|
||||||
|
}
|
||||||
|
offset := m.messageOffsets[m.selectedMessage]
|
||||||
|
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
|
||||||
|
}
|
||||||
|
m.updateContent()
|
||||||
|
m.input.Blur()
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "ctrl+s":
|
||||||
|
if m.state != idle {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
input := strings.TrimSpace(m.input.Value())
|
||||||
|
if input == "" {
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(m.App.Messages) > 0 && m.App.Messages[len(m.App.Messages)-1].Role.IsUser() {
|
||||||
|
return shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addMessage(conversation.Message{
|
||||||
|
Role: api.MessageRoleUser,
|
||||||
|
Content: input,
|
||||||
|
})
|
||||||
|
|
||||||
|
m.input.SetValue("")
|
||||||
|
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
if m.persistence {
|
||||||
|
cmds = append(cmds, m.persistConversation())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmds = append(cmds, m.promptLLM())
|
||||||
|
|
||||||
|
m.updateContent()
|
||||||
|
m.content.GotoBottom()
|
||||||
|
return tea.Batch(cmds...)
|
||||||
|
case "ctrl+e":
|
||||||
|
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
|
||||||
|
m.editorTarget = input
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
270
pkg/tui/views/chat/update.go
Normal file
270
pkg/tui/views/chat/update.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
|
"github.com/charmbracelet/bubbles/cursor"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *Model) setMessage(i int, msg conversation.Message) {
|
||||||
|
if i >= len(m.App.Messages) {
|
||||||
|
panic("i out of range")
|
||||||
|
}
|
||||||
|
m.App.Messages[i] = msg
|
||||||
|
m.messageCache[i] = m.renderMessage(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) addMessage(msg conversation.Message) {
|
||||||
|
m.App.Messages = append(m.App.Messages, msg)
|
||||||
|
m.messageCache = append(m.messageCache, m.renderMessage(len(m.App.Messages)-1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) setMessageContents(i int, content string) {
|
||||||
|
if i >= len(m.App.Messages) {
|
||||||
|
panic("i out of range")
|
||||||
|
}
|
||||||
|
m.App.Messages[i].Content = content
|
||||||
|
m.messageCache[i] = m.renderMessage(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) rebuildMessageCache() {
|
||||||
|
m.messageCache = make([]string, len(m.App.Messages))
|
||||||
|
for i := range m.App.Messages {
|
||||||
|
m.messageCache[i] = m.renderMessage(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) updateContent() {
|
||||||
|
atBottom := m.content.AtBottom()
|
||||||
|
m.content.SetContent(m.conversationMessagesView())
|
||||||
|
if atBottom {
|
||||||
|
m.content.GotoBottom()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
||||||
|
inputHandled := false
|
||||||
|
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
cmd := m.handleInput(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
inputHandled = true
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.Width, m.Height = msg.Width, msg.Height
|
||||||
|
m.content.Width = msg.Width
|
||||||
|
m.input.SetWidth(msg.Width - m.input.FocusedStyle.Base.GetHorizontalFrameSize())
|
||||||
|
if len(m.App.Messages) > 0 {
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
}
|
||||||
|
case shared.MsgViewEnter:
|
||||||
|
// wake up spinners and cursors
|
||||||
|
cmds = append(cmds, cursor.Blink, m.spinner.Tick)
|
||||||
|
|
||||||
|
// Refresh view
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
|
||||||
|
if m.App.Conversation.ID > 0 {
|
||||||
|
// (re)load conversation contents
|
||||||
|
cmds = append(cmds, m.loadConversationMessages())
|
||||||
|
}
|
||||||
|
case tuiutil.MsgTempfileEditorClosed:
|
||||||
|
contents := string(msg)
|
||||||
|
switch m.editorTarget {
|
||||||
|
case input:
|
||||||
|
m.input.SetValue(contents)
|
||||||
|
case selectedMessage:
|
||||||
|
toEdit := m.App.Messages[m.selectedMessage]
|
||||||
|
if toEdit.Content != contents {
|
||||||
|
toEdit.Content = contents
|
||||||
|
m.setMessage(m.selectedMessage, toEdit)
|
||||||
|
if m.persistence && toEdit.ID > 0 {
|
||||||
|
// create clone of message with its new contents
|
||||||
|
cmds = append(cmds, m.cloneMessage(toEdit, true))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case msgConversationMessagesLoaded:
|
||||||
|
m.App.Messages = msg.messages
|
||||||
|
if m.selectedMessage == -1 {
|
||||||
|
m.selectedMessage = len(msg.messages) - 1
|
||||||
|
} else {
|
||||||
|
m.selectedMessage = min(m.selectedMessage, len(m.App.Messages))
|
||||||
|
}
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
case msgChatResponseChunk:
|
||||||
|
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
|
||||||
|
|
||||||
|
if msg.Content == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
last := len(m.App.Messages) - 1
|
||||||
|
if last >= 0 && m.App.Messages[last].Role.IsAssistant() {
|
||||||
|
// append chunk to existing message
|
||||||
|
m.setMessageContents(last, m.App.Messages[last].Content+msg.Content)
|
||||||
|
} else {
|
||||||
|
// use chunk in a new message
|
||||||
|
m.addMessage(conversation.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Content: msg.Content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
m.updateContent()
|
||||||
|
|
||||||
|
// show cursor and reset blink interval (simulate typing)
|
||||||
|
m.replyCursor.Blink = false
|
||||||
|
cmds = append(cmds, m.replyCursor.BlinkCmd())
|
||||||
|
|
||||||
|
m.tokenCount += msg.TokenCount
|
||||||
|
m.elapsed = time.Now().Sub(m.startTime)
|
||||||
|
case msgChatResponse:
|
||||||
|
m.state = idle
|
||||||
|
|
||||||
|
reply := conversation.Message(msg)
|
||||||
|
reply.Content = strings.TrimSpace(reply.Content)
|
||||||
|
|
||||||
|
last := len(m.App.Messages) - 1
|
||||||
|
if last < 0 {
|
||||||
|
panic("Unexpected empty messages handling msgAssistantReply")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.App.Messages[last].Role.IsAssistant() {
|
||||||
|
// TODO: handle continuations gracefully - only some models support them
|
||||||
|
m.setMessage(last, reply)
|
||||||
|
} else {
|
||||||
|
m.addMessage(reply)
|
||||||
|
}
|
||||||
|
|
||||||
|
if reply.Role == api.MessageRoleToolCall {
|
||||||
|
// TODO: user confirmation before execution
|
||||||
|
// m.state = confirmToolUse
|
||||||
|
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.persistence {
|
||||||
|
cmds = append(cmds, m.persistConversation())
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.App.Conversation.Title == "" && len(m.App.Messages) > 0 {
|
||||||
|
cmds = append(cmds, m.generateConversationTitle())
|
||||||
|
}
|
||||||
|
case msgChatResponseCanceled:
|
||||||
|
m.state = idle
|
||||||
|
m.updateContent()
|
||||||
|
case msgChatResponseError:
|
||||||
|
m.state = idle
|
||||||
|
m.updateContent()
|
||||||
|
return m, shared.WrapError(msg.Err)
|
||||||
|
case msgToolResults:
|
||||||
|
last := len(m.App.Messages) - 1
|
||||||
|
if last < 0 {
|
||||||
|
panic("Unexpected empty messages handling msgAssistantReply")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.App.Messages[last].Role != api.MessageRoleToolCall {
|
||||||
|
panic("Previous message not a tool call, unexpected")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addMessage(conversation.Message{
|
||||||
|
Role: api.MessageRoleToolResult,
|
||||||
|
ToolResults: conversation.ToolResults(msg),
|
||||||
|
})
|
||||||
|
|
||||||
|
if m.persistence {
|
||||||
|
cmds = append(cmds, m.persistConversation())
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateContent()
|
||||||
|
case msgConversationTitleGenerated:
|
||||||
|
title := string(msg)
|
||||||
|
m.App.Conversation.Title = title
|
||||||
|
if m.persistence && m.App.Conversation.ID > 0 {
|
||||||
|
cmds = append(cmds, m.persistConversation())
|
||||||
|
}
|
||||||
|
case cursor.BlinkMsg:
|
||||||
|
if m.state == pendingResponse {
|
||||||
|
// ensure we show the updated "wait for response" cursor blink state
|
||||||
|
last := len(m.App.Messages) - 1
|
||||||
|
m.messageCache[last] = m.renderMessage(last)
|
||||||
|
m.updateContent()
|
||||||
|
}
|
||||||
|
case msgConversationPersisted:
|
||||||
|
m.App.Conversation = conversation.Conversation(msg)
|
||||||
|
cmds = append(cmds, m.persistMessages())
|
||||||
|
case msgMessagesPersisted:
|
||||||
|
m.App.Messages = msg
|
||||||
|
m.rebuildMessageCache()
|
||||||
|
m.updateContent()
|
||||||
|
case msgMessageCloned:
|
||||||
|
cmds = append(cmds, m.loadConversationMessages())
|
||||||
|
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
|
||||||
|
cmds = append(cmds, m.loadConversationMessages())
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.spinner, cmd = m.spinner.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
m.replyCursor, cmd = m.replyCursor.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
prevInputLineCnt := m.input.LineCount()
|
||||||
|
|
||||||
|
if !inputHandled {
|
||||||
|
m.input, cmd = m.input.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
inputHandled = true
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inputHandled {
|
||||||
|
m.content, cmd = m.content.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is a pretty nasty hack to ensure the input area viewport doesn't
|
||||||
|
// scroll below its content, which can happen when the input viewport
|
||||||
|
// height has grown, or previously entered lines have been deleted
|
||||||
|
if prevInputLineCnt != m.input.LineCount() {
|
||||||
|
// dist is the distance we'd need to scroll up from the current cursor
|
||||||
|
// position to position the last input line at the bottom of the
|
||||||
|
// viewport. if negative, we're already scrolled above the bottom
|
||||||
|
dist := m.input.Line() - (m.input.LineCount() - m.input.Height())
|
||||||
|
if dist > 0 {
|
||||||
|
for i := 0; i < dist; i++ {
|
||||||
|
// move cursor up until content reaches the bottom of the viewport
|
||||||
|
m.input.CursorUp()
|
||||||
|
}
|
||||||
|
m.input, _ = m.input.Update(nil)
|
||||||
|
for i := 0; i < dist; i++ {
|
||||||
|
// move cursor back down to its previous position
|
||||||
|
m.input.CursorDown()
|
||||||
|
}
|
||||||
|
m.input, _ = m.input.Update(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmds) > 0 {
|
||||||
|
return m, tea.Batch(cmds...)
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
365
pkg/tui/views/chat/view.go
Normal file
365
pkg/tui/views/chat/view.go
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
package chat
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
"github.com/muesli/reflow/wordwrap"
|
||||||
|
"github.com/muesli/reflow/wrap"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// styles
|
||||||
|
var (
|
||||||
|
boldStyle = lipgloss.NewStyle().Bold(true)
|
||||||
|
faintStyle = lipgloss.NewStyle().Faint(true)
|
||||||
|
boldFaintStyle = lipgloss.NewStyle().Faint(true).Bold(true)
|
||||||
|
|
||||||
|
messageHeadingStyle = lipgloss.NewStyle().
|
||||||
|
MarginTop(1).
|
||||||
|
MarginBottom(1)
|
||||||
|
|
||||||
|
userStyle = boldFaintStyle.Foreground(lipgloss.Color("10"))
|
||||||
|
|
||||||
|
assistantStyle = boldFaintStyle.Foreground(lipgloss.Color("12"))
|
||||||
|
|
||||||
|
systemStyle = boldStyle.Foreground(lipgloss.Color("8"))
|
||||||
|
|
||||||
|
messageStyle = lipgloss.NewStyle().
|
||||||
|
PaddingLeft(2).
|
||||||
|
PaddingRight(2)
|
||||||
|
|
||||||
|
inputFocusedStyle = lipgloss.NewStyle().
|
||||||
|
Border(lipgloss.RoundedBorder(), true, true, true, false)
|
||||||
|
|
||||||
|
inputBlurredStyle = lipgloss.NewStyle().
|
||||||
|
Faint(true).
|
||||||
|
Border(lipgloss.RoundedBorder(), true, true, true, false)
|
||||||
|
|
||||||
|
footerStyle = lipgloss.NewStyle().Padding(0, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *Model) renderMessageHeading(i int, message *conversation.Message) string {
|
||||||
|
friendly := message.Role.FriendlyRole()
|
||||||
|
style := systemStyle
|
||||||
|
|
||||||
|
switch message.Role {
|
||||||
|
case api.MessageRoleUser:
|
||||||
|
style = userStyle
|
||||||
|
case api.MessageRoleAssistant:
|
||||||
|
style = assistantStyle
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
style = assistantStyle
|
||||||
|
friendly = api.MessageRoleAssistant.FriendlyRole()
|
||||||
|
case api.MessageRoleSystem:
|
||||||
|
case api.MessageRoleToolResult:
|
||||||
|
}
|
||||||
|
|
||||||
|
user := style.Render(friendly)
|
||||||
|
|
||||||
|
var prefix, suffix string
|
||||||
|
|
||||||
|
if i == m.selectedMessage && m.focus == focusMessages {
|
||||||
|
prefix = "> "
|
||||||
|
} else {
|
||||||
|
prefix = " "
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 && m.App.Conversation.SelectedRootID != nil && len(m.App.Conversation.RootMessages) > 1 {
|
||||||
|
selectedRootIndex := 0
|
||||||
|
for j, reply := range m.App.Conversation.RootMessages {
|
||||||
|
if reply.ID == *m.App.Conversation.SelectedRootID {
|
||||||
|
selectedRootIndex = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.App.Conversation.RootMessages)))
|
||||||
|
}
|
||||||
|
if i > 0 && len(m.App.Messages[i-1].Replies) > 1 {
|
||||||
|
// Find the selected reply index
|
||||||
|
selectedReplyIndex := 0
|
||||||
|
for j, reply := range m.App.Messages[i-1].Replies {
|
||||||
|
if reply.ID == *m.App.Messages[i-1].SelectedReplyID {
|
||||||
|
selectedReplyIndex = j
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
suffix += faintStyle.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.App.Messages[i-1].Replies)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.ID == 0 {
|
||||||
|
suffix += faintStyle.Render(" (not saved)")
|
||||||
|
}
|
||||||
|
|
||||||
|
heading := prefix + user + suffix
|
||||||
|
|
||||||
|
if message.Metadata.GenerationModel != nil && m.showDetails {
|
||||||
|
heading += faintStyle.Render(
|
||||||
|
fmt.Sprintf(" | %s", *message.Metadata.GenerationModel),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messageHeadingStyle.Render(heading)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderMessages renders the message at the given index as it should be shown
|
||||||
|
// *at this moment* - we render differently depending on the current application
|
||||||
|
// state (window size, etc, etc).
|
||||||
|
func (m *Model) renderMessage(i int) string {
|
||||||
|
msg := &m.App.Messages[i]
|
||||||
|
|
||||||
|
// Write message contents
|
||||||
|
sb := &strings.Builder{}
|
||||||
|
sb.Grow(len(msg.Content) * 2)
|
||||||
|
if msg.Content != "" {
|
||||||
|
err := m.App.Ctx.Chroma.Highlight(sb, msg.Content)
|
||||||
|
if err != nil {
|
||||||
|
sb.Reset()
|
||||||
|
sb.WriteString(msg.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isLast := i == len(m.App.Messages)-1
|
||||||
|
isAssistant := msg.Role == api.MessageRoleAssistant
|
||||||
|
|
||||||
|
if m.state == pendingResponse && isLast && isAssistant {
|
||||||
|
// Show the assistant's cursor
|
||||||
|
sb.WriteString(m.replyCursor.View())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write tool call info
|
||||||
|
var toolString string
|
||||||
|
switch msg.Role {
|
||||||
|
case api.MessageRoleToolCall:
|
||||||
|
bytes, err := yaml.Marshal(msg.ToolCalls)
|
||||||
|
if err != nil {
|
||||||
|
toolString = "Could not serialize ToolCalls"
|
||||||
|
} else {
|
||||||
|
toolString = "tool_calls:\n" + string(bytes)
|
||||||
|
}
|
||||||
|
case api.MessageRoleToolResult:
|
||||||
|
type renderedResult struct {
|
||||||
|
ToolName string `yaml:"tool"`
|
||||||
|
Result any `yaml:"result,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolResults []renderedResult
|
||||||
|
for _, result := range msg.ToolResults {
|
||||||
|
if m.showDetails {
|
||||||
|
var jsonResult interface{}
|
||||||
|
err := json.Unmarshal([]byte(result.Result), &jsonResult)
|
||||||
|
if err != nil {
|
||||||
|
// If parsing as JSON fails, treat Result as a plain string
|
||||||
|
toolResults = append(toolResults, renderedResult{
|
||||||
|
ToolName: result.ToolName,
|
||||||
|
Result: result.Result,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// If parsing as JSON succeeds, marshal the parsed JSON into YAML
|
||||||
|
toolResults = append(toolResults, renderedResult{
|
||||||
|
ToolName: result.ToolName,
|
||||||
|
Result: &jsonResult,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Only show the tool name when results are hidden
|
||||||
|
toolResults = append(toolResults, renderedResult{
|
||||||
|
ToolName: result.ToolName,
|
||||||
|
Result: "(hidden, press ctrl+t to view)",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes, err := yaml.Marshal(toolResults)
|
||||||
|
if err != nil {
|
||||||
|
toolString = "Could not serialize ToolResults"
|
||||||
|
} else {
|
||||||
|
toolString = "tool_results:\n" + string(bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if toolString != "" {
|
||||||
|
toolString = strings.TrimRight(toolString, "\n")
|
||||||
|
if msg.Content != "" {
|
||||||
|
sb.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
_ = m.App.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
|
||||||
|
}
|
||||||
|
|
||||||
|
content := strings.TrimRight(sb.String(), "\n")
|
||||||
|
|
||||||
|
if m.wrap {
|
||||||
|
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding()
|
||||||
|
// first we word-wrap text to slightly less than desired width (since
|
||||||
|
// wordwrap seems to have an off-by-1 issue), then hard wrap at
|
||||||
|
// desired with
|
||||||
|
content = wrap.String(wordwrap.String(content, wrapWidth-2), wrapWidth)
|
||||||
|
}
|
||||||
|
|
||||||
|
return messageStyle.Width(0).Render(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// render the conversation into a string
|
||||||
|
func (m *Model) conversationMessagesView() string {
|
||||||
|
m.messageOffsets = make([]int, len(m.App.Messages))
|
||||||
|
lineCnt := 1
|
||||||
|
|
||||||
|
sb := strings.Builder{}
|
||||||
|
for i, message := range m.App.Messages {
|
||||||
|
m.messageOffsets[i] = lineCnt
|
||||||
|
|
||||||
|
if !m.showDetails && message.Role.IsSystem() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
heading := m.renderMessageHeading(i, &message)
|
||||||
|
sb.WriteString(heading)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
lineCnt += lipgloss.Height(heading)
|
||||||
|
|
||||||
|
rendered := m.messageCache[i]
|
||||||
|
sb.WriteString(rendered)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
lineCnt += lipgloss.Height(rendered)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render a placeholder for the incoming assistant reply
|
||||||
|
if m.state == pendingResponse && m.App.Messages[len(m.App.Messages)-1].Role != api.MessageRoleAssistant {
|
||||||
|
heading := m.renderMessageHeading(-1, &conversation.Message{
|
||||||
|
Role: api.MessageRoleAssistant,
|
||||||
|
Metadata: conversation.MessageMeta{
|
||||||
|
GenerationModel: &m.App.Model,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
sb.WriteString(heading)
|
||||||
|
sb.WriteString("\n")
|
||||||
|
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Content(width, height int) string {
|
||||||
|
// calculate clamped input height to accomodate input text
|
||||||
|
// minimum 4 lines, maximum half of content area
|
||||||
|
inputHeight := max(4, min(height/2, m.input.LineCount()))
|
||||||
|
m.input.SetHeight(inputHeight)
|
||||||
|
input := m.input.View()
|
||||||
|
|
||||||
|
// remaining height towards content
|
||||||
|
m.content.Width, m.content.Height = width, height-tuiutil.Height(input)
|
||||||
|
content := m.content.View()
|
||||||
|
return lipgloss.JoinVertical(lipgloss.Left, content, input)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Header(width int) string {
|
||||||
|
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
var title string
|
||||||
|
if m.App.Conversation.Title != "" {
|
||||||
|
title = m.App.Conversation.Title
|
||||||
|
} else {
|
||||||
|
title = "Untitled"
|
||||||
|
}
|
||||||
|
title = tuiutil.TruncateRightToCellWidth(title, width-styles.Header.GetHorizontalPadding(), "...")
|
||||||
|
header := titleStyle.Render(title)
|
||||||
|
return styles.Header.Width(width).Render(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Footer(width int) string {
|
||||||
|
segmentStyle := lipgloss.NewStyle().Faint(true)
|
||||||
|
segmentSeparator := segmentStyle.Render(" | ")
|
||||||
|
|
||||||
|
// Left segments
|
||||||
|
leftSegments := make([]string, 0, 4)
|
||||||
|
|
||||||
|
if m.state == pendingResponse {
|
||||||
|
leftSegments = append(leftSegments, segmentStyle.Render(m.spinner.View()))
|
||||||
|
} else {
|
||||||
|
leftSegments = append(leftSegments, segmentStyle.Render("∙∙∙"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.elapsed > 0 && m.tokenCount > 0 {
|
||||||
|
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
|
||||||
|
leftSegments = append(leftSegments, segmentStyle.Render(throughput))
|
||||||
|
}
|
||||||
|
|
||||||
|
// var status string
|
||||||
|
// switch m.state {
|
||||||
|
// case pendingResponse:
|
||||||
|
// status = "Press ctrl+c to cancel"
|
||||||
|
// default:
|
||||||
|
// status = "Press ctrl+s to send"
|
||||||
|
// }
|
||||||
|
// leftSegments = append(leftSegments, segmentStyle.Render(status))
|
||||||
|
|
||||||
|
// Right segments
|
||||||
|
rightSegments := make([]string, 0, 8)
|
||||||
|
|
||||||
|
if m.App.Agent != nil {
|
||||||
|
rightSegments = append(rightSegments, segmentStyle.Render(m.App.Agent.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
model := segmentStyle.Render(m.App.ActiveModel(lipgloss.NewStyle()))
|
||||||
|
rightSegments = append(rightSegments, model)
|
||||||
|
|
||||||
|
savingStyle := segmentStyle.Bold(true)
|
||||||
|
saving := ""
|
||||||
|
if m.persistence {
|
||||||
|
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("💾✅")
|
||||||
|
} else {
|
||||||
|
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("💾❌")
|
||||||
|
}
|
||||||
|
rightSegments = append(rightSegments, saving)
|
||||||
|
|
||||||
|
return m.layoutFooter(width, leftSegments, rightSegments, segmentSeparator)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) layoutFooter(
|
||||||
|
width int,
|
||||||
|
leftSegments []string,
|
||||||
|
rightSegments []string,
|
||||||
|
segmentSeparator string,
|
||||||
|
) string {
|
||||||
|
left := strings.Join(leftSegments, segmentSeparator)
|
||||||
|
right := strings.Join(rightSegments, segmentSeparator)
|
||||||
|
|
||||||
|
leftWidth := tuiutil.Width(left)
|
||||||
|
rightWidth := tuiutil.Width(right)
|
||||||
|
sepWidth := tuiutil.Width(segmentSeparator)
|
||||||
|
frameWidth := footerStyle.GetHorizontalFrameSize()
|
||||||
|
|
||||||
|
availableWidth := width - frameWidth - leftWidth - rightWidth
|
||||||
|
|
||||||
|
if availableWidth >= sepWidth {
|
||||||
|
// Everything fits
|
||||||
|
padding := strings.Repeat(" ", availableWidth)
|
||||||
|
return footerStyle.Render(left + padding + right)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inserted between left and right segments when they're being truncated
|
||||||
|
div := "..."
|
||||||
|
|
||||||
|
totalAvailableWidth := width - frameWidth
|
||||||
|
availableTruncWidth := totalAvailableWidth - len(div)
|
||||||
|
|
||||||
|
minVisibleLength := 3
|
||||||
|
if availableTruncWidth < 2*minVisibleLength {
|
||||||
|
minVisibleLength = availableTruncWidth / 2
|
||||||
|
}
|
||||||
|
|
||||||
|
leftProportion := float64(leftWidth) / float64(leftWidth+rightWidth)
|
||||||
|
|
||||||
|
newLeftWidth := int(max(float64(minVisibleLength), leftProportion*float64(availableTruncWidth)))
|
||||||
|
newRightWidth := totalAvailableWidth - newLeftWidth
|
||||||
|
|
||||||
|
truncatedLeft := faintStyle.Render(tuiutil.TruncateRightToCellWidth(left, newLeftWidth, ""))
|
||||||
|
truncatedRight := faintStyle.Render(tuiutil.TruncateLeftToCellWidth(right, newRightWidth, "..."))
|
||||||
|
|
||||||
|
return footerStyle.Width(width).Render(truncatedLeft + truncatedRight)
|
||||||
|
}
|
326
pkg/tui/views/conversations/conversations.go
Normal file
326
pkg/tui/views/conversations/conversations.go
Normal file
@ -0,0 +1,326 @@
|
|||||||
|
package conversations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/conversation"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
|
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
// sent when conversation list is loaded
|
||||||
|
msgConversationsLoaded conversation.ConversationList
|
||||||
|
// sent when a single conversation is loaded
|
||||||
|
msgConversationLoaded conversation.Conversation
|
||||||
|
// sent when a conversation is deleted
|
||||||
|
msgConversationDeleted struct{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
App *model.AppModel
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
|
||||||
|
cursor int
|
||||||
|
itemOffsets []int // conversation y offsets
|
||||||
|
|
||||||
|
content viewport.Model
|
||||||
|
confirmPrompt bubbles.ConfirmPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
func Conversations(app *model.AppModel) *Model {
|
||||||
|
viewport.New(0, 0)
|
||||||
|
m := Model{
|
||||||
|
App: app,
|
||||||
|
content: viewport.New(0, 0),
|
||||||
|
}
|
||||||
|
return &m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) handleInput(msg tea.KeyMsg) tea.Cmd {
|
||||||
|
if m.confirmPrompt.Focused() {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
m.confirmPrompt, cmd = m.confirmPrompt.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
conversations := m.App.Conversations.Items
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "enter":
|
||||||
|
if len(conversations) > 0 && m.cursor < len(conversations) {
|
||||||
|
return m.loadConversation(conversations[m.cursor].ID)
|
||||||
|
}
|
||||||
|
case "j", "down":
|
||||||
|
if m.cursor < len(conversations)-1 {
|
||||||
|
m.cursor++
|
||||||
|
if m.cursor == len(conversations)-1 {
|
||||||
|
m.content.GotoBottom()
|
||||||
|
} else {
|
||||||
|
// this hack positions the *next* conversatoin slightly
|
||||||
|
// *off* the screen, ensuring the entire m.cursor is shown,
|
||||||
|
// even if its height may not be constant due to wrapping.
|
||||||
|
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor+1], -1)
|
||||||
|
}
|
||||||
|
m.content.SetContent(m.renderConversationList())
|
||||||
|
} else {
|
||||||
|
m.cursor = len(conversations) - 1
|
||||||
|
m.content.GotoBottom()
|
||||||
|
}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "k", "up":
|
||||||
|
if m.cursor > 0 {
|
||||||
|
m.cursor--
|
||||||
|
if m.cursor == 0 {
|
||||||
|
m.content.GotoTop()
|
||||||
|
} else {
|
||||||
|
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor], 1)
|
||||||
|
}
|
||||||
|
m.content.SetContent(m.renderConversationList())
|
||||||
|
} else {
|
||||||
|
m.cursor = 0
|
||||||
|
m.content.GotoTop()
|
||||||
|
}
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
case "n":
|
||||||
|
m.App.NewConversation()
|
||||||
|
return shared.ChangeView(shared.ViewChat)
|
||||||
|
case "d":
|
||||||
|
if !m.confirmPrompt.Focused() && len(conversations) > 0 && m.cursor < len(conversations) {
|
||||||
|
title := conversations[m.cursor].Title
|
||||||
|
if title == "" {
|
||||||
|
title = "(untitled)"
|
||||||
|
}
|
||||||
|
m.confirmPrompt = bubbles.NewConfirmPrompt(
|
||||||
|
fmt.Sprintf("Delete '%s'?", title),
|
||||||
|
conversations[m.cursor],
|
||||||
|
)
|
||||||
|
m.confirmPrompt.Style = lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("3"))
|
||||||
|
return shared.KeyHandled(msg)
|
||||||
|
}
|
||||||
|
case "c":
|
||||||
|
// copy/clone conversation
|
||||||
|
case "r":
|
||||||
|
// show prompt to rename conversation
|
||||||
|
case "shift+r":
|
||||||
|
// show prompt to generate name for conversation
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Init() tea.Cmd {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
||||||
|
isInput := false
|
||||||
|
inputHandled := false
|
||||||
|
|
||||||
|
var cmds []tea.Cmd
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
isInput = true
|
||||||
|
cmd := m.handleInput(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
inputHandled = true
|
||||||
|
}
|
||||||
|
case shared.MsgViewEnter:
|
||||||
|
cmds = append(cmds, m.loadConversations())
|
||||||
|
m.content.SetContent(m.renderConversationList())
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.width, m.height = msg.Width, msg.Height
|
||||||
|
m.content.SetContent(m.renderConversationList())
|
||||||
|
case msgConversationsLoaded:
|
||||||
|
m.App.Conversations = conversation.ConversationList(msg)
|
||||||
|
m.cursor = max(0, min(len(m.App.Conversations.Items), m.cursor))
|
||||||
|
m.content.SetContent(m.renderConversationList())
|
||||||
|
case msgConversationLoaded:
|
||||||
|
m.App.ClearConversation()
|
||||||
|
m.App.Conversation = conversation.Conversation(msg)
|
||||||
|
cmds = append(cmds, func() tea.Msg {
|
||||||
|
return shared.MsgViewChange(shared.ViewChat)
|
||||||
|
})
|
||||||
|
case bubbles.MsgConfirmPromptAnswered:
|
||||||
|
m.confirmPrompt.Blur()
|
||||||
|
if msg.Value {
|
||||||
|
conv, ok := msg.Payload.(conversation.ConversationListItem)
|
||||||
|
if ok {
|
||||||
|
cmds = append(cmds, m.deleteConversation(conv))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case msgConversationDeleted:
|
||||||
|
cmds = append(cmds, m.loadConversations())
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isInput || !inputHandled {
|
||||||
|
content, cmd := m.content.Update(msg)
|
||||||
|
m.content = content
|
||||||
|
if cmd != nil {
|
||||||
|
cmds = append(cmds, cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cmds) > 0 {
|
||||||
|
return m, tea.Batch(cmds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) loadConversations() tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
list, err := m.App.Ctx.Conversations.LoadConversationList()
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(fmt.Errorf("Could not load conversations: %v", err))
|
||||||
|
}
|
||||||
|
return msgConversationsLoaded(list)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) loadConversation(conversationID uint) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
conversation, err := m.App.Ctx.Conversations.GetConversationByID(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(fmt.Errorf("Could not load conversation %d: %v", conversationID, err))
|
||||||
|
}
|
||||||
|
return msgConversationLoaded(*conversation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) deleteConversation(conv conversation.ConversationListItem) tea.Cmd {
|
||||||
|
return func() tea.Msg {
|
||||||
|
err := m.App.Ctx.Conversations.DeleteConversationById(conv.ID)
|
||||||
|
if err != nil {
|
||||||
|
return shared.AsMsgError(fmt.Errorf("Could not delete conversation: %v", err))
|
||||||
|
}
|
||||||
|
return msgConversationDeleted{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Header(width int) string {
|
||||||
|
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
header := titleStyle.Render("Conversations")
|
||||||
|
return styles.Header.Width(width).Render(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Content(width int, height int) string {
|
||||||
|
m.content.Width, m.content.Height = width, height
|
||||||
|
return m.content.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Footer(width int) string {
|
||||||
|
if m.confirmPrompt.Focused() {
|
||||||
|
return lipgloss.NewStyle().Width(width).Render(m.confirmPrompt.View())
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) renderConversationList() string {
|
||||||
|
now := time.Now()
|
||||||
|
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||||
|
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||||
|
dayOfWeek := int(now.Weekday())
|
||||||
|
categories := []struct {
|
||||||
|
name string
|
||||||
|
cutoff time.Duration
|
||||||
|
}{
|
||||||
|
{"Today", now.Sub(midnight)},
|
||||||
|
{"Yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||||
|
{"This week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||||
|
{"Last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||||
|
{"This month", now.Sub(monthStart)},
|
||||||
|
{"Last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||||
|
{"2 Months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||||
|
{"3 Months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||||
|
{"4 Months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||||
|
{"5 Months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||||
|
{"6 Months ago", now.Sub(monthStart.AddDate(0, -6, 0))},
|
||||||
|
{"Older", now.Sub(time.Time{})},
|
||||||
|
}
|
||||||
|
|
||||||
|
categoryStyle := lipgloss.NewStyle().
|
||||||
|
MarginBottom(1).
|
||||||
|
Foreground(lipgloss.Color("12")).
|
||||||
|
PaddingLeft(1).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
itemStyle := lipgloss.NewStyle().
|
||||||
|
MarginBottom(1)
|
||||||
|
|
||||||
|
ageStyle := lipgloss.NewStyle().Faint(true).SetString()
|
||||||
|
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
untitledStyle := lipgloss.NewStyle().Faint(true).Italic(true)
|
||||||
|
selectedStyle := lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6"))
|
||||||
|
|
||||||
|
var (
|
||||||
|
currentOffset int
|
||||||
|
currentCategory string
|
||||||
|
sb strings.Builder
|
||||||
|
)
|
||||||
|
|
||||||
|
m.itemOffsets = make([]int, len(m.App.Conversations.Items))
|
||||||
|
sb.WriteRune('\n')
|
||||||
|
currentOffset += 1
|
||||||
|
|
||||||
|
for i, c := range m.App.Conversations.Items {
|
||||||
|
lastReplyAge := now.Sub(c.LastMessageAt)
|
||||||
|
|
||||||
|
var category string
|
||||||
|
for _, g := range categories {
|
||||||
|
if lastReplyAge < g.cutoff {
|
||||||
|
category = g.name
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// print the category
|
||||||
|
if category != currentCategory {
|
||||||
|
currentCategory = category
|
||||||
|
heading := categoryStyle.Render(currentCategory)
|
||||||
|
sb.WriteString(heading)
|
||||||
|
currentOffset += tuiutil.Height(heading)
|
||||||
|
sb.WriteRune('\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
tStyle := titleStyle
|
||||||
|
if c.Title == "" {
|
||||||
|
tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)")
|
||||||
|
}
|
||||||
|
if i == m.cursor {
|
||||||
|
tStyle = tStyle.Inherit(selectedStyle)
|
||||||
|
}
|
||||||
|
|
||||||
|
title := tStyle.Width(m.width - 3).PaddingLeft(2).Render(c.Title)
|
||||||
|
if i == m.cursor {
|
||||||
|
title = ">" + title[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
m.itemOffsets[i] = currentOffset
|
||||||
|
item := itemStyle.Render(fmt.Sprintf(
|
||||||
|
"%s\n %s",
|
||||||
|
title,
|
||||||
|
ageStyle.Render(util.HumanTimeElapsedSince(lastReplyAge)),
|
||||||
|
))
|
||||||
|
sb.WriteString(item)
|
||||||
|
currentOffset += tuiutil.Height(item)
|
||||||
|
if i < len(m.App.Conversations.Items)-1 {
|
||||||
|
sb.WriteRune('\n')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
137
pkg/tui/views/settings/settings.go
Normal file
137
pkg/tui/views/settings/settings.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package settings
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles/list"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/model"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||||
|
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||||
|
"github.com/charmbracelet/bubbles/viewport"
|
||||||
|
tea "github.com/charmbracelet/bubbletea"
|
||||||
|
"github.com/charmbracelet/lipgloss"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
App *model.AppModel
|
||||||
|
prevView shared.View
|
||||||
|
content viewport.Model
|
||||||
|
modelList list.Model
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
}
|
||||||
|
|
||||||
|
type modelOpt struct {
|
||||||
|
provider string
|
||||||
|
model string
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelListId int = iota + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
func Settings(app *model.AppModel) *Model {
|
||||||
|
m := &Model{
|
||||||
|
App: app,
|
||||||
|
content: viewport.New(0, 0),
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Init() tea.Cmd {
|
||||||
|
m.modelList = list.NewWithGroups(m.getModelOptions())
|
||||||
|
m.modelList.ID = modelListId
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Update(msg tea.Msg) (shared.ViewModel, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
|
||||||
|
switch msg := msg.(type) {
|
||||||
|
case tea.KeyMsg:
|
||||||
|
m.modelList, cmd = m.modelList.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.String() {
|
||||||
|
case "esc":
|
||||||
|
return m, func() tea.Msg {
|
||||||
|
return shared.MsgViewChange(m.prevView)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case shared.MsgViewEnter:
|
||||||
|
m.prevView = shared.View(msg)
|
||||||
|
m.modelList.Focus()
|
||||||
|
m.content.SetContent(m.renderContent())
|
||||||
|
case tea.WindowSizeMsg:
|
||||||
|
m.width, m.height = msg.Width, msg.Height
|
||||||
|
m.content.Width = msg.Width
|
||||||
|
m.content.Height = msg.Height
|
||||||
|
m.content.SetContent(m.renderContent())
|
||||||
|
case list.MsgOptionSelected:
|
||||||
|
switch msg.ID {
|
||||||
|
case modelListId:
|
||||||
|
if modelOpt, ok := msg.Option.Value.(modelOpt); ok {
|
||||||
|
m.App.Model = modelOpt.model
|
||||||
|
m.App.ProviderName = modelOpt.provider
|
||||||
|
}
|
||||||
|
return m, shared.ChangeView(m.prevView)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.modelList, cmd = m.modelList.Update(msg)
|
||||||
|
if cmd != nil {
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
m.content.SetContent(m.renderContent())
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) getModelOptions() []list.OptionGroup {
|
||||||
|
modelOpts := []list.OptionGroup{}
|
||||||
|
for _, p := range m.App.Ctx.Config.Providers {
|
||||||
|
provider := p.Name
|
||||||
|
if provider == "" {
|
||||||
|
provider = p.Kind
|
||||||
|
}
|
||||||
|
providerLabel := p.Display
|
||||||
|
if providerLabel == "" {
|
||||||
|
providerLabel = strings.ToUpper(provider[:1]) + provider[1:]
|
||||||
|
}
|
||||||
|
group := list.OptionGroup{
|
||||||
|
Name: providerLabel,
|
||||||
|
}
|
||||||
|
for _, model := range p.Models {
|
||||||
|
group.Options = append(group.Options, list.Option{
|
||||||
|
Label: model,
|
||||||
|
Value: modelOpt{provider, model},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
modelOpts = append(modelOpts, group)
|
||||||
|
}
|
||||||
|
return modelOpts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Header(width int) string {
|
||||||
|
boldStyle := lipgloss.NewStyle().Bold(true)
|
||||||
|
// TODO: update header depending on active settings mode (model, agent, etc)
|
||||||
|
header := boldStyle.Render("Model selection")
|
||||||
|
return styles.Header.Width(width).Render(header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Content(width, height int) string {
|
||||||
|
// TODO: see Header()
|
||||||
|
currentModel := " Active model: " + m.App.ActiveModel(lipgloss.NewStyle())
|
||||||
|
m.modelList.Width, m.modelList.Height = width, height - 2
|
||||||
|
return "\n" + currentModel + "\n" + m.modelList.View()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Footer(width int) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) renderContent() string {
|
||||||
|
return m.modelList.View()
|
||||||
|
}
|
@ -58,3 +58,29 @@ func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
|
|||||||
s.formatter.Format(&sb, s.style, it)
|
s.formatter.Format(&sb, s.style, it)
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ChromaHighlighter) HighlightLang(w io.Writer, text string, lang string) (error) {
|
||||||
|
l := lexers.Get(lang)
|
||||||
|
if l == nil {
|
||||||
|
l = lexers.Fallback
|
||||||
|
}
|
||||||
|
l = chroma.Coalesce(l)
|
||||||
|
old := s.lexer
|
||||||
|
s.lexer = l
|
||||||
|
err := s.Highlight(w, text)
|
||||||
|
s.lexer = old
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ChromaHighlighter) HighlightLangS(text string, lang string) (string, error) {
|
||||||
|
l := lexers.Get(lang)
|
||||||
|
if l == nil {
|
||||||
|
l = lexers.Fallback
|
||||||
|
}
|
||||||
|
l = chroma.Coalesce(l)
|
||||||
|
old := s.lexer
|
||||||
|
s.lexer = l
|
||||||
|
highlighted, err := s.HighlightS(text)
|
||||||
|
s.lexer = old
|
||||||
|
return highlighted, err
|
||||||
|
}
|
||||||
|
@ -21,7 +21,7 @@ func InputFromEditor(placeholder string, pattern string, content string) (string
|
|||||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||||
defer os.Remove(msgFile.Name())
|
defer os.Remove(msgFile.Name())
|
||||||
|
|
||||||
os.WriteFile(msgFile.Name(), []byte(placeholder + content), os.ModeAppend)
|
os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
||||||
|
|
||||||
editor := os.Getenv("EDITOR")
|
editor := os.Getenv("EDITOR")
|
||||||
if editor == "" {
|
if editor == "" {
|
||||||
@ -137,8 +137,8 @@ func SetStructDefaults(data interface{}) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get the "default" struct tag
|
// Get the "default" struct tag
|
||||||
defaultTag := v.Type().Field(i).Tag.Get("default")
|
defaultTag, ok := v.Type().Field(i).Tag.Lookup("default")
|
||||||
if defaultTag == "" {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,16 +147,16 @@ func SetStructDefaults(data interface{}) bool {
|
|||||||
case reflect.String:
|
case reflect.String:
|
||||||
defaultValue := defaultTag
|
defaultValue := defaultTag
|
||||||
field.Set(reflect.ValueOf(&defaultValue))
|
field.Set(reflect.ValueOf(&defaultValue))
|
||||||
|
case reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||||
|
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
|
||||||
|
field.Set(reflect.New(e))
|
||||||
|
field.Elem().SetUint(intValue)
|
||||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits())
|
||||||
field.Set(reflect.New(e))
|
field.Set(reflect.New(e))
|
||||||
field.Elem().SetInt(intValue)
|
field.Elem().SetInt(intValue)
|
||||||
case reflect.Float32:
|
case reflect.Float32, reflect.Float64:
|
||||||
floatValue, _ := strconv.ParseFloat(defaultTag, 32)
|
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
|
||||||
field.Set(reflect.New(e))
|
|
||||||
field.Elem().SetFloat(floatValue)
|
|
||||||
case reflect.Float64:
|
|
||||||
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
|
|
||||||
field.Set(reflect.New(e))
|
field.Set(reflect.New(e))
|
||||||
field.Elem().SetFloat(floatValue)
|
field.Elem().SetFloat(floatValue)
|
||||||
case reflect.Bool:
|
case reflect.Bool:
|
||||||
|
Loading…
Reference in New Issue
Block a user