Compare commits
171 Commits
Author | SHA1 | Date |
---|---|---|
Matt Low | e59ce973b6 | |
Matt Low | 4ef841e945 | |
Matt Low | c68084f8a5 | |
Matt Low | 8ca044b6af | |
Matt Low | 6f5cf68208 | |
Matt Low | 914d9ac0c1 | |
Matt Low | 8ddac2f820 | |
Matt Low | cea5118cac | |
Matt Low | a43a91c6ff | |
Matt Low | ba7018af11 | |
Matt Low | f89cc7b410 | |
Matt Low | 677cfcfebf | |
Matt Low | 11402c5534 | |
Matt Low | a1fc8a637b | |
Matt Low | 94d84ba7d7 | |
Matt Low | c50b6b154d | |
Matt Low | 31df055430 | |
Matt Low | c30e652103 | |
Matt Low | 3fde58b77d | |
Matt Low | 85a2abbbf3 | |
Matt Low | dfe43179c0 | |
Matt Low | 42c3297e54 | |
Matt Low | a22119f738 | |
Matt Low | a2c860252f | |
Matt Low | d2d946b776 | |
Matt Low | c963747066 | |
Matt Low | e334d9fc4f | |
Matt Low | c1ead83939 | |
Matt Low | c9e92e186e | |
Matt Low | 45df957a06 | |
Matt Low | 136c463924 | |
Matt Low | 2580087b4d | |
Matt Low | 60a474d516 | |
Matt Low | ea576d24a6 | |
Matt Low | 465b1d333e | |
Matt Low | b29a4c8b84 | |
Matt Low | 58e1b84fea | |
Matt Low | a6522dbcd0 | |
Matt Low | 97cd047861 | |
Matt Low | ed784bb1cf | |
Matt Low | c1792f27ff | |
Matt Low | 0ad698a942 | |
Matt Low | 0d66a49997 | |
Matt Low | 008fdc0d37 | |
Matt Low | eec9eb41e9 | |
Matt Low | 437997872a | |
Matt Low | 3536438dd1 | |
Matt Low | f5ce970102 | |
Matt Low | 5c1248184b | |
Matt Low | 8c53752146 | |
Matt Low | f6e55f6bff | |
Matt Low | dc1edf8c3e | |
Matt Low | 62d98289e8 | |
Matt Low | b82f3019f0 | |
Matt Low | 1bd953676d | |
Matt Low | a291e7b42c | |
Matt Low | 1b8d04c96d | |
Matt Low | cbcd3b1ba9 | |
Matt Low | 75bf9f6125 | |
Matt Low | 9ff4322995 | |
Matt Low | 54f5a3c209 | |
Matt Low | 86bdc733bf | |
Matt Low | 60394de620 | |
Matt Low | aeeb7bb7f7 | |
Matt Low | 2b38db7db7 | |
Matt Low | 8e4ff90ab4 | |
Matt Low | bdaf6204f6 | |
Matt Low | 1b9a8f319c | |
Matt Low | ffe9d299ef | |
Matt Low | 08a2027332 | |
Matt Low | b06e031ee0 | |
Matt Low | 69d3265b64 | |
Matt Low | 7463b7502c | |
Matt Low | 0e68e22efa | |
Matt Low | 1404cae6a7 | |
Matt Low | 9e6d41a3ff | |
Matt Low | 39cd4227c6 | |
Matt Low | 105ee2e01b | |
Matt Low | e1970a315a | |
Matt Low | 020db40401 | |
Matt Low | 811ec4b251 | |
Matt Low | c68cb14eb9 | |
Matt Low | cef87a55d8 | |
Matt Low | 29519fa2f3 | |
Matt Low | 2e3779ad32 | |
Matt Low | 9cd28d28d7 | |
Matt Low | 0b991800d6 | |
Matt Low | 5af857edae | |
Matt Low | 3e24a54d0a | |
Matt Low | a669313a0b | |
Matt Low | 6310021dca | |
Matt Low | ef929da68c | |
Matt Low | c51644e78e | |
Matt Low | 91c74d9e1e | |
Matt Low | 3185b2d7d6 | |
Matt Low | 6c64f21d9a | |
Matt Low | 6f737ad19c | |
Matt Low | a8ffdc156a | |
Matt Low | 7a974d9764 | |
Matt Low | adb61ffa59 | |
Matt Low | 1c7ad75fd5 | |
Matt Low | 613aa1a552 | |
Matt Low | 71833b89cd | |
Matt Low | 2ad93394b1 | |
Matt Low | f49b772960 | |
Matt Low | 29d8138dc0 | |
Matt Low | 3756f6d9e4 | |
Matt Low | 41916eb7b3 | |
Matt Low | 3892e68251 | |
Matt Low | 8697284064 | |
Matt Low | 383d34f311 | |
Matt Low | ac0e380244 | |
Matt Low | c3a3cb0181 | |
Matt Low | 612ea90417 | |
Matt Low | 94508b1dbf | |
Matt Low | 7e002e5214 | |
Matt Low | 48e4dea3cf | |
Matt Low | 0ab552303d | |
Matt Low | 6ce42a77f9 | |
Matt Low | 2cb1a0005d | |
Matt Low | ea78edf039 | |
Matt Low | 793aaab50e | |
Matt Low | 5afc9667c7 | |
Matt Low | dfafc573e5 | |
Matt Low | 97f81a0cbb | |
Matt Low | eca120cde6 | |
Matt Low | 12d4e495d4 | |
Matt Low | d8c8262890 | |
Matt Low | 758f74aba5 | |
Matt Low | 1570c23d63 | |
Matt Low | 46149e0b67 | |
Matt Low | c2c61e2aaa | |
Matt Low | 5e880d3b31 | |
Matt Low | 62f07dd240 | |
Matt Low | ec1f326c2a | |
Matt Low | db116660a5 | |
Matt Low | 32eab7aa35 | |
Matt Low | 91d3c9c2e1 | |
Matt Low | 8bdb155bf7 | |
Matt Low | 045146bb5c | |
Matt Low | 2c7bdd8ebf | |
Matt Low | 7d56726c78 | |
Matt Low | f2c7d2bdd0 | |
Matt Low | 0a27b9a8d3 | |
Matt Low | 2611663168 | |
Matt Low | 120e61e88b | |
Matt Low | fa966d30db | |
Matt Low | 51ce74ad3a | |
Matt Low | b93ee94233 | |
Matt Low | db788760a3 | |
Matt Low | 242ed886ec | |
Matt Low | 02a23b9035 | |
Matt Low | b3913d0027 | |
Matt Low | 1184f9aaae | |
Matt Low | a25d0d95e8 | |
Matt Low | becaa5c7c0 | |
Matt Low | 239ded18f3 | |
Matt Low | 59e78669c8 | |
Matt Low | 1966ec881b | |
Matt Low | f6ded3e20e | |
Matt Low | 1e8ff60c54 | |
Matt Low | af2fccd4ee | |
Matt Low | f206334e72 | |
Matt Low | 5615051637 | |
Matt Low | c46500de4e | |
Matt Low | d5dde10dbf | |
Matt Low | d32e9421fe | |
Matt Low | e29dbaf2a3 | |
Matt Low | c64bc370f4 | |
Matt Low | 4f37ed046b | |
Matt Low | ed6ee9bea9 |
181
README.md
181
README.md
|
@ -1,37 +1,174 @@
|
|||
# lmcli
|
||||
# lmcli - Large ____ Model CLI
|
||||
|
||||
`lmcli` is a (Large) Language Model CLI.
|
||||
`lmcli` is a versatile command-line interface for interacting with LLMs and LMMs.
|
||||
|
||||
Current features:
|
||||
- Perform one-shot prompts with `lmcli prompt <message>`
|
||||
- Manage persistent conversations with the `new`, `reply`, `view`, and `rm`
|
||||
sub-commands.
|
||||
- Syntax highlighted output
|
||||
## Features
|
||||
|
||||
Planned features:
|
||||
- Ask questions about content received on stdin
|
||||
- "functions" to allow reading (and possibly writing) to files within the
|
||||
current working directory
|
||||
- Multiple model backends (Ollama, OpenAI, Anthropic, Google)
|
||||
- Customizable agents with tool calling
|
||||
- Persistent conversation management
|
||||
- Message branching (edit and re-prompt to your heart's desire)
|
||||
- Interactive terminal interface for seamless chat experiences
|
||||
- Utilizes `$EDITOR`, write and edit prompts from the comfort of your own editor :)
|
||||
- `vi`-like bindings
|
||||
- Syntax highlighting!
|
||||
|
||||
Maybe features:
|
||||
- Natural language image generation, iterative editing
|
||||
## Screenshots
|
||||
|
||||
## Install
|
||||
[TODO: Add screenshots of the TUI in action, showing different views and features]
|
||||
|
||||
```shell
|
||||
$ go install git.mlow.ca/mlow/lmcli@latest
|
||||
## Installation
|
||||
|
||||
To install `lmcli`, make sure you have Go installed on your system, then run:
|
||||
|
||||
```sh
|
||||
go install git.mlow.ca/mlow/lmcli@latest
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
`lmcli` uses a YAML configuration file located at `~/.config/lmcli/config.yaml`. Here's a sample configuration:
|
||||
|
||||
```yaml
|
||||
defaults:
|
||||
model: claude-3-5-sonnet-20240620
|
||||
maxTokens: 3072
|
||||
temperature: 0.2
|
||||
conversations:
|
||||
titleGenerationModel: claude-3-haiku-20240307
|
||||
chroma:
|
||||
style: onedark
|
||||
formatter: terminal16m
|
||||
agents:
|
||||
- name: coder
|
||||
tools:
|
||||
- dir_tree
|
||||
- read_file
|
||||
#- write_file
|
||||
systemPrompt: |
|
||||
You are an experienced software engineer...
|
||||
# ...
|
||||
providers:
|
||||
- kind: ollama
|
||||
models:
|
||||
- phi3:instruct
|
||||
- llama3:8b
|
||||
- kind: anthropic
|
||||
apiKey: your-api-key-here
|
||||
models:
|
||||
- claude-3-5-sonnet-20240620
|
||||
- claude-3-opus-20240229
|
||||
- claude-3-haiku-20240307
|
||||
- kind: openai
|
||||
apiKey: your-api-key-here
|
||||
models:
|
||||
- gpt-4o
|
||||
- gpt-4-turbo
|
||||
- name: openrouter
|
||||
kind: openai
|
||||
apiKey: your-api-key-here
|
||||
baseUrl: https://openrouter.ai/api/
|
||||
models:
|
||||
- qwen/qwen-2-72b-instruct
|
||||
# ...
|
||||
```
|
||||
|
||||
Customize this file to add your own providers, agents, and models.
|
||||
|
||||
### Syntax highlighting
|
||||
|
||||
Syntax highlighting is performed by [Chroma](https://github.com/alecthomas/chroma).
|
||||
|
||||
Refer to [`Chroma/styles`](https://github.com/alecthomas/chroma/tree/master/styles) for available styles (TODO: add support for custom Chroma styles).
|
||||
|
||||
Available formatters:
|
||||
- `terminal` - 8 colors
|
||||
- `terminal16` - 16 colors
|
||||
- `terminal256` - 256 colors
|
||||
- `terminal16m` - true color (default)
|
||||
|
||||
## Agents
|
||||
|
||||
Agents in `lmcli` combine a system prompt with a set of available tools. Agents are defined in `config.yaml` and are called upon with the `-a`/`--agent` flag.
|
||||
|
||||
Agent functionality is expected to be expanded on, bringing them to close parity with something like OpenAI's "Assistants" feature.
|
||||
|
||||
## Tools
|
||||
|
||||
Tools are used by agents to acquire information from and interact with external systems. The following built-in tools are available:
|
||||
|
||||
- `dir_tree`: Display a directory structure
|
||||
- `read_file`: Read the contents of a file
|
||||
- `write_file`: Write content to a file
|
||||
- `file_insert_lines`: Insert lines at a specific position in a file
|
||||
- `file_replace_lines`: Replace a range of lines in a file
|
||||
|
||||
Obviously, some of these tools carry significant risk. Use wisely :)
|
||||
|
||||
More tool features are planned, including the ability to define arbitrary tools which call out to external scripts, tools to spawn sub-agents, perform web searches, etc.
|
||||
|
||||
## Usage
|
||||
|
||||
Invoke `lmcli` at least once:
|
||||
|
||||
```shell
|
||||
```console
|
||||
$ lmcli help
|
||||
lmcli - Large Language Model CLI
|
||||
|
||||
Usage:
|
||||
lmcli <command> [flags]
|
||||
lmcli [command]
|
||||
|
||||
Available Commands:
|
||||
chat Open the chat interface
|
||||
clone Clone conversations
|
||||
completion Generate the autocompletion script for the specified shell
|
||||
continue Continue a conversation from the last message
|
||||
edit Edit the last user reply in a conversation
|
||||
help Help about any command
|
||||
list List conversations
|
||||
new Start a new conversation
|
||||
prompt Do a one-shot prompt
|
||||
rename Rename a conversation
|
||||
reply Reply to a conversation
|
||||
retry Retry the last user reply in a conversation
|
||||
rm Remove conversations
|
||||
view View messages in a conversation
|
||||
|
||||
Flags:
|
||||
-h, --help help for lmcli
|
||||
|
||||
Use "lmcli [command] --help" for more information about a command.
|
||||
```
|
||||
|
||||
Edit `~/.config/lmcli/config.yaml` and set `openai.apiKey` to your API key.
|
||||
### Examples
|
||||
|
||||
Refer back to the output of `lmcli help` for usage.
|
||||
Start a new chat with the `coder` agent:
|
||||
|
||||
Enjoy!
|
||||
```console
|
||||
$ lmcli chat --agent coder
|
||||
```
|
||||
|
||||
Start a new conversation, imperative style (no tui):
|
||||
|
||||
```console
|
||||
$ lmcli new "Help me plan meals for the next week"
|
||||
```
|
||||
|
||||
Send a one-shot prompt (no persistence):
|
||||
|
||||
```console
|
||||
$ lmcli prompt "What is the answer to life, the universe, and everything?"
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to `lmcli` are welcome! Feel free to open issues or submit pull requests on the project repository.
|
||||
|
||||
For a full list of planned features and improvements, check out the [TODO.md](TODO.md) file.
|
||||
|
||||
## License
|
||||
|
||||
To be determined
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
`lmcli` is a small hobby project. Special thanks to the Go community and the creators of the libraries used in this project.
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
# TODO
|
||||
|
||||
- [x] Strip anthropic XML function call scheme from content, to reconstruct
|
||||
when calling anthropic?
|
||||
- [x] `dir_tree` tool
|
||||
- [x] Implement native Anthropic API tool calling
|
||||
- [x] Agents - a name given to a system prompt + set of available tools +
|
||||
potentially other relevent data (e.g. external service credentials, files for
|
||||
RAG, etc), which the user explicitly selects (e.g. `lmcli chat --agent
|
||||
code-helper`, `lmcli chat -a financier`).
|
||||
- [ ] Specialized agents which have integrations beyond basic tool calling,
|
||||
e.g. a coding agent which bakes in efficient code context management
|
||||
(only the current state of relevant files get shown to the model in the
|
||||
system prompt, rather than having them in the conversation messages)
|
||||
- [ ] Agents may have some form of long term memory management (key-value?
|
||||
natural lang?).
|
||||
- [ ] Sandboxed python, js interpreter (both useful for different reasons)
|
||||
- [ ] Support for arbitrary external script tools
|
||||
- [ ] Search - RAG driven search of existing conversation "hey, remind me of
|
||||
the conversation we had six months ago about X")
|
||||
- [ ] Conversation categorization - model driven category creation and
|
||||
conversation classification
|
||||
- [ ] Image input
|
||||
- [ ] Image output (sixel support?)
|
||||
- [ ] Conversation exports to html/pdf/json
|
||||
|
||||
## UI
|
||||
- [x] Prettify/normalize tool_call and tool_result outputs so they can be
|
||||
shown/optionally hidden in `lmcli view` and `lmcli chat`
|
||||
- [ ] User confirmation before calling (some?) tools
|
||||
- [ ] Conversation deletion in conversations view
|
||||
- [ ] Message deletion, Ctrl+D to delete a message and attach its children to
|
||||
its parent, Ctrl+Shift+D to delete a message and its descendents
|
||||
- [ ] Show available key bindings and their action in any given view
|
24
go.mod
24
go.mod
|
@ -4,25 +4,39 @@ go 1.21
|
|||
|
||||
require (
|
||||
github.com/alecthomas/chroma/v2 v2.11.1
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/gookit/color v1.5.4
|
||||
github.com/sashabaranov/go-openai v1.17.7
|
||||
github.com/charmbracelet/bubbles v0.18.0
|
||||
github.com/charmbracelet/bubbletea v0.25.0
|
||||
github.com/charmbracelet/lipgloss v0.10.0
|
||||
github.com/muesli/reflow v0.3.0
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/sqids/sqids-go v0.4.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/sqlite v1.5.4
|
||||
gorm.io/gorm v1.25.5
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.15.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.14.0 // indirect
|
||||
golang.org/x/term v0.6.0 // indirect
|
||||
golang.org/x/text v0.3.8 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
||||
)
|
||||
|
|
57
go.sum
57
go.sum
|
@ -4,16 +4,22 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
|
|||
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
|
||||
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
|
||||
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
|
||||
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
|
||||
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
|
||||
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
|
@ -26,33 +32,52 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
|||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
|
||||
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sashabaranov/go-openai v1.17.7 h1:MPcAwlwbeo7ZmhQczoOgZBHtIBY1TfZqsdx6+/ndloM=
|
||||
github.com/sashabaranov/go-openai v1.17.7/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
||||
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
|
||||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||
|
|
17
main.go
17
main.go
|
@ -1,15 +1,18 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/cli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cli.Execute(); err != nil {
|
||||
fmt.Fprint(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
ctx, err := lmcli.NewContext()
|
||||
if err != nil {
|
||||
lmcli.Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
root := cmd.RootCmd(ctx)
|
||||
if err := root.Execute(); err != nil {
|
||||
lmcli.Fatal("%v\n", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const TREE_DESCRIPTION = `Retrieve a tree-like view of a directory's contents.
|
||||
|
||||
Use these results for your own reference in completing your task, they do not need to be shown to the user.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": ".
|
||||
├── a_directory/
|
||||
│ ├── file1.txt (100 bytes)
|
||||
│ └── file2.txt (200 bytes)
|
||||
├── a_file.txt (123 bytes)
|
||||
└── another_file.txt (456 bytes)"
|
||||
}
|
||||
`
|
||||
|
||||
var DirTreeTool = api.ToolSpec{
|
||||
Name: "dir_tree",
|
||||
Description: TREE_DESCRIPTION,
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "relative_path",
|
||||
Type: "string",
|
||||
Description: "If set, display the tree starting from this path relative to the current one.",
|
||||
},
|
||||
{
|
||||
Name: "depth",
|
||||
Type: "integer",
|
||||
Description: "Depth of directory recursion. Defaults to 0 (no recursion), maximum of 5.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
if tmp, ok := args["relative_path"]; ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("expected string for relative_path, got %T", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
var depth int = 0 // Default value if not provided
|
||||
if tmp, ok := args["depth"]; ok {
|
||||
switch v := tmp.(type) {
|
||||
case float64:
|
||||
depth = int(v)
|
||||
case string:
|
||||
var err error
|
||||
if depth, err = strconv.Atoi(v); err != nil {
|
||||
return "", fmt.Errorf("invalid `depth` value, expected integer but got string that cannot convert: %v", tmp)
|
||||
}
|
||||
depth = max(0, min(5, depth))
|
||||
default:
|
||||
return "", fmt.Errorf("expected int or string for max_depth, got %T", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := tree(relativeDir, depth)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func tree(path string, depth int) api.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
var treeOutput strings.Builder
|
||||
treeOutput.WriteString(path + "\n")
|
||||
err := buildTree(&treeOutput, path, "", depth)
|
||||
if err != nil {
|
||||
return api.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
return api.CallResult{Result: treeOutput.String()}
|
||||
}
|
||||
|
||||
func buildTree(output *strings.Builder, path string, prefix string, depth int) error {
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, file := range files {
|
||||
if strings.HasPrefix(file.Name(), ".") {
|
||||
// Skip hidden files and directories
|
||||
continue
|
||||
}
|
||||
|
||||
isLast := i == len(files)-1
|
||||
var branch string
|
||||
if isLast {
|
||||
branch = "└── "
|
||||
} else {
|
||||
branch = "├── "
|
||||
}
|
||||
|
||||
info, _ := file.Info()
|
||||
size := info.Size()
|
||||
sizeStr := fmt.Sprintf(" (%d bytes)", size)
|
||||
|
||||
output.WriteString(prefix + branch + file.Name())
|
||||
if file.IsDir() {
|
||||
output.WriteString("/\n")
|
||||
if depth > 0 {
|
||||
var nextPrefix string
|
||||
if isLast {
|
||||
nextPrefix = prefix + " "
|
||||
} else {
|
||||
nextPrefix = prefix + "│ "
|
||||
}
|
||||
buildTree(output, filepath.Join(path, file.Name()), nextPrefix, depth-1)
|
||||
}
|
||||
} else {
|
||||
output.WriteString(sizeStr + "\n")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,114 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
var FileInsertLinesTool = api.ToolSpec{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "position",
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileInsertLines(path, position, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileInsertLines(path string, position int, content string) api.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return api.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return api.CallResult{Result: newContent}
|
||||
}
|
|
@ -0,0 +1,178 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
var MODIFY_FILE_DESCRIPTION = []string{
|
||||
"Modify a file. If the file does not exist, it will be created.",
|
||||
"",
|
||||
"Content can be either inserted, replaced, or removed through a combination of the start, stop, and content parameters.",
|
||||
"Use the start and stop line numbers to limit the range of modification to the file.",
|
||||
"If both `start` and `stop` are left unset (or set to 0), the entire file's contents will be updated.",
|
||||
"If `start` is set to n and `stop` to n+1, content will be inserted at line n (the content that was at line n will be shifted below the newly inserted content).",
|
||||
"If only `start` is set, content from the given line and onwards will be updated.",
|
||||
"If only `stop` is set, content up to but not including the given line will be updated.",
|
||||
"",
|
||||
"Examples:",
|
||||
"1. Append to a file:",
|
||||
" {\"path\": \"example.txt\", \"start\": <last_line_number + 1>, \"content\": \"New content to append\"}",
|
||||
"",
|
||||
"2. Insert at a specific line:",
|
||||
" {\"path\": \"example.txt\", \"start\": 5, \"stop\": 5, \"content\": \"New line inserted above the previous line 5\"}",
|
||||
"",
|
||||
"3. Replace a range of lines:",
|
||||
" {\"path\": \"example.txt\", \"start\": 3, \"stop\": 7, \"content\": \"New content replacing lines 3-7\"}",
|
||||
"",
|
||||
"4. Remove a range of lines:",
|
||||
" {\"path\": \"example.txt\", \"start\": 2, \"stop\": 5}",
|
||||
"",
|
||||
"5. Replace entire file contents:",
|
||||
" {\"path\": \"example.txt\", \"content\": \"New file contents\"}",
|
||||
"",
|
||||
"6. Update from a specific line to the end of the file:",
|
||||
" {\"path\": \"example.txt\", \"start\": 10, \"content\": \"New content from line 10 onwards\"}",
|
||||
"",
|
||||
"7. Update from the beginning of the file to a specific line:",
|
||||
" {\"path\": \"example.txt\", \"stop\": 6, \"content\": \"New content for first 5 lines\"}",
|
||||
"",
|
||||
"Note: Always use specific line numbers based on the current file content. Avoid using arbitrarily large numbers for start or stop.",
|
||||
}
|
||||
|
||||
var ModifyFile = api.ToolSpec{
|
||||
Name: "modify_file",
|
||||
Description: strings.Join(MODIFY_FILE_DESCRIPTION, "\n"),
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "start",
|
||||
Type: "integer",
|
||||
Description: `Start line of the range to modify (inclusive). If omitted, the beginning of the file is implied.`,
|
||||
},
|
||||
{
|
||||
Name: "stop",
|
||||
Type: "integer",
|
||||
Description: `End line of the range to modify (inclusive). If omitted, the end of the file is implied.`,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: "Content to insert/replace at the range defined by `start` and `stop`. If omitted, the range is removed.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to modify_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start int
|
||||
tmp, ok = args["start"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start in function arguments: %v", tmp)
|
||||
}
|
||||
start = int(tmp)
|
||||
}
|
||||
var stop int
|
||||
tmp, ok = args["stop"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid stop in function arguments: %v", tmp)
|
||||
}
|
||||
stop = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileModifyContents(path, start, stop, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
func fileModifyContents(path string, startLine int, stopLine int, content string) api.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
|
||||
|
||||
// If both start and stop are unset, update the entire file
|
||||
if startLine == 0 && stopLine == 0 {
|
||||
lines = contentLines
|
||||
} else {
|
||||
if startLine < 1 {
|
||||
startLine = 1
|
||||
}
|
||||
if stopLine == 0 || stopLine > len(lines) {
|
||||
stopLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[stopLine:]
|
||||
|
||||
// Handle insertion case
|
||||
if startLine == stopLine {
|
||||
lines = append(before, append(contentLines, lines[startLine-1:]...)...)
|
||||
} else {
|
||||
// If content is omitted, remove the specified range
|
||||
if content == "" {
|
||||
lines = append(before, after...)
|
||||
} else {
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return api.CallResult{Result: util.AddLineNumbers(newContent)}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": [
|
||||
{"name": "a_file.txt", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size of the file, in bytes.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
var ReadDirTool = api.ToolSpec{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "relative_dir",
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
result := readDir(relativeDir)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readDir(path string) api.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return api.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return api.CallResult{Result: dirContents}
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const READ_FILE_DESCRIPTION = `Retrieve the contents of a text file relative to the current working directory.
|
||||
|
||||
Use the file contents for your own reference in completing your task, they do not need to be shown to the user.
|
||||
|
||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
var ReadFileTool = api.ToolSpec{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
|
||||
Impl: func(tool *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
result := readFile(path)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readFile(path string) api.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
return api.CallResult{
|
||||
Result: toolutil.AddLineNumbers(string(data)),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func IsPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func IsPathWithinCWD(path string) (bool, string) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, "Failed to determine current working directory"
|
||||
}
|
||||
if ok, err := IsPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
|
||||
}
|
||||
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
|
||||
// AddLineNumbers takes a string of content and returns a new string with line
|
||||
// numbers prefixed
|
||||
func AddLineNumbers(content string) string {
|
||||
lines := strings.Split(strings.TrimSuffix(content, "\n"), "\n")
|
||||
result := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
result.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
return result.String()
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
package toolbox
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/agents/toolbox/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success"
|
||||
}`
|
||||
|
||||
var WriteFileTool = api.ToolSpec{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: []api.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(t *api.ToolSpec, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
result := writeFile(path, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func writeFile(path string, content string) api.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return api.CallResult{Message: reason}
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return api.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
return api.CallResult{}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package agents
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agents/toolbox"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
var AvailableTools map[string]api.ToolSpec = map[string]api.ToolSpec{
|
||||
"dir_tree": toolbox.DirTreeTool,
|
||||
"read_dir": toolbox.ReadDirTool,
|
||||
"read_file": toolbox.ReadFileTool,
|
||||
"modify_file": toolbox.ModifyFile,
|
||||
"write_file": toolbox.WriteFileTool,
|
||||
}
|
||||
|
||||
func ExecuteToolCalls(calls []api.ToolCall, available []api.ToolSpec) ([]api.ToolResult, error) {
|
||||
var toolResults []api.ToolResult
|
||||
for _, call := range calls {
|
||||
var tool *api.ToolSpec
|
||||
for i := range available {
|
||||
if available[i].Name == call.Name {
|
||||
tool = &available[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if tool == nil {
|
||||
return nil, fmt.Errorf("Requested tool '%s' is not available. Hallucination?", call.Name)
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
result, err := tool.Impl(tool, call.Parameters)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", call.Name, err)
|
||||
}
|
||||
|
||||
toolResult := api.ToolResult{
|
||||
ToolCallID: call.ID,
|
||||
ToolName: call.Name,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ReplyCallback func(Message)
|
||||
|
||||
type Chunk struct {
|
||||
Content string
|
||||
TokenCount uint
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
Model string
|
||||
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
Toolbox []ToolSpec
|
||||
}
|
||||
|
||||
type ChatCompletionProvider interface {
|
||||
// CreateChatCompletion requests a response to the provided messages.
|
||||
// Replies are appended to the given replies struct, and the
|
||||
// complete user-facing response is returned as a string.
|
||||
CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []Message,
|
||||
) (*Message, error)
|
||||
|
||||
// Like CreateChageCompletion, except the response is streamed via
|
||||
// the output channel as it's received.
|
||||
CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params RequestParameters,
|
||||
messages []Message,
|
||||
chunks chan<- Chunk,
|
||||
) (*Message, error)
|
||||
}
|
||||
|
||||
func IsAssistantContinuation(messages []Message) bool {
|
||||
if len(messages) == 0 {
|
||||
return false
|
||||
}
|
||||
return messages[len(messages)-1].Role == MessageRoleAssistant
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
package api
|
||||
|
||||
import "database/sql"
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
SelectedRootID *uint
|
||||
SelectedRoot *Message `gorm:"foreignKey:SelectedRootID"`
|
||||
}
|
|
@ -0,0 +1,72 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID *uint `gorm:"index"`
|
||||
Conversation *Conversation `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the model)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
ParentID *uint
|
||||
Parent *Message `gorm:"foreignKey:ParentID"`
|
||||
Replies []Message `gorm:"foreignKey:ParentID"`
|
||||
|
||||
SelectedReplyID *uint
|
||||
SelectedReply *Message `gorm:"foreignKey:SelectedReplyID"`
|
||||
}
|
||||
|
||||
func ApplySystemPrompt(m []Message, system string, force bool) []Message {
|
||||
if len(m) > 0 && m[0].Role == MessageRoleSystem {
|
||||
if force {
|
||||
m[0].Content = system
|
||||
}
|
||||
return m
|
||||
} else {
|
||||
return append([]Message{{
|
||||
Role: MessageRoleSystem,
|
||||
Content: system,
|
||||
}}, m...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MessageRole) IsAssistant() bool {
|
||||
switch *m {
|
||||
case MessageRoleAssistant, MessageRoleToolCall:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m MessageRole) FriendlyRole() string {
|
||||
switch m {
|
||||
case MessageRoleUser:
|
||||
return "You"
|
||||
case MessageRoleSystem:
|
||||
return "System"
|
||||
case MessageRoleAssistant:
|
||||
return "Assistant"
|
||||
case MessageRoleToolCall:
|
||||
return "Tool Call"
|
||||
case MessageRoleToolResult:
|
||||
return "Tool Result"
|
||||
default:
|
||||
return string(m)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,450 @@
|
|||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
const ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
type AnthropicClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
InputSchema InputSchema `json:"input_schema"`
|
||||
}
|
||||
|
||||
type InputSchema struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]Property `json:"properties"`
|
||||
Required []string `json:"required"`
|
||||
}
|
||||
|
||||
type Property struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ContentBlock struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Input interface{} `json:"input,omitempty"`
|
||||
partialJsonAccumulator string
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Model string `json:"model"`
|
||||
Content []ContentBlock `json:"content"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type StreamEvent struct {
|
||||
Type string `json:"type"`
|
||||
Message interface{} `json:"message,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
Delta interface{} `json:"delta,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
anthropicTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
properties := make(map[string]Property)
|
||||
for _, param := range tool.Parameters {
|
||||
properties[param.Name] = Property{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
}
|
||||
|
||||
var required []string
|
||||
for _, param := range tool.Parameters {
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
anthropicTools[i] = Tool{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
InputSchema: InputSchema{
|
||||
Type: "object",
|
||||
Properties: properties,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return anthropicTools
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (string, ChatCompletionRequest) {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
var systemMessage string
|
||||
|
||||
for _, m := range messages {
|
||||
if m.Role == api.MessageRoleSystem {
|
||||
systemMessage = m.Content
|
||||
continue
|
||||
}
|
||||
|
||||
var content interface{}
|
||||
role := string(m.Role)
|
||||
|
||||
switch m.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
role = "assistant"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
if m.Content != "" {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": m.Content,
|
||||
})
|
||||
}
|
||||
for _, toolCall := range m.ToolCalls {
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCall.ID,
|
||||
"name": toolCall.Name,
|
||||
"input": toolCall.Parameters,
|
||||
})
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
case api.MessageRoleToolResult:
|
||||
role = "user"
|
||||
contentBlocks := make([]map[string]interface{}, 0)
|
||||
for _, result := range m.ToolResults {
|
||||
contentBlock := map[string]interface{}{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": result.ToolCallID,
|
||||
"content": result.Result,
|
||||
}
|
||||
contentBlocks = append(contentBlocks, contentBlock)
|
||||
}
|
||||
content = contentBlocks
|
||||
|
||||
default:
|
||||
content = m.Content
|
||||
}
|
||||
|
||||
requestMessages = append(requestMessages, ChatCompletionMessage{
|
||||
Role: role,
|
||||
Content: content,
|
||||
})
|
||||
}
|
||||
|
||||
request := ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
System: systemMessage,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
}
|
||||
|
||||
var prefill string
|
||||
if api.IsAssistantContinuation(messages) {
|
||||
prefill = messages[len(messages)-1].Content
|
||||
}
|
||||
|
||||
return prefill, request
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||
jsonData, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/messages", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", ANTHROPIC_VERSION)
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
_, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||
}
|
||||
|
||||
return convertResponseToMessage(completionResp)
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("can't create completion from no messages")
|
||||
}
|
||||
|
||||
prefill, req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
contentBlocks := make(map[int]*ContentBlock)
|
||||
var finalMessage *ChatCompletionResponse
|
||||
|
||||
var firstChunkReceived bool
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var streamEvent StreamEvent
|
||||
err = json.Unmarshal(line, &streamEvent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal stream event: %w", err)
|
||||
}
|
||||
|
||||
switch streamEvent.Type {
|
||||
case "message_start":
|
||||
finalMessage = &ChatCompletionResponse{}
|
||||
err = json.Unmarshal(line, &struct {
|
||||
Message *ChatCompletionResponse `json:"message"`
|
||||
}{Message: finalMessage})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal message_start: %w", err)
|
||||
}
|
||||
case "content_block_start":
|
||||
var contentBlockStart struct {
|
||||
Index int `json:"index"`
|
||||
ContentBlock ContentBlock `json:"content_block"`
|
||||
}
|
||||
err = json.Unmarshal(line, &contentBlockStart)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal content_block_start: %w", err)
|
||||
}
|
||||
|
||||
contentBlocks[contentBlockStart.Index] = &contentBlockStart.ContentBlock
|
||||
case "content_block_delta":
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received delta for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected delta type: %T", streamEvent.Delta)
|
||||
}
|
||||
|
||||
deltaType, ok := delta["type"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("delta missing type field")
|
||||
}
|
||||
|
||||
switch deltaType {
|
||||
case "text_delta":
|
||||
if text, ok := delta["text"].(string); ok {
|
||||
if !firstChunkReceived {
|
||||
if prefill == "" {
|
||||
// if there is no prefil, ensure we trim leading whitespace
|
||||
text = strings.TrimSpace(text)
|
||||
}
|
||||
firstChunkReceived = true
|
||||
}
|
||||
block.Text += text
|
||||
output <- api.Chunk{
|
||||
Content: text,
|
||||
TokenCount: 1,
|
||||
}
|
||||
}
|
||||
case "input_json_delta":
|
||||
if block.Type != "tool_use" {
|
||||
return nil, fmt.Errorf("received input_json_delta for non-tool_use block")
|
||||
}
|
||||
if partialJSON, ok := delta["partial_json"].(string); ok {
|
||||
block.partialJsonAccumulator += partialJSON
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
if streamEvent.Index >= len(contentBlocks) {
|
||||
return nil, fmt.Errorf("received stop for non-existent content block index: %d", streamEvent.Index)
|
||||
}
|
||||
|
||||
block := contentBlocks[streamEvent.Index]
|
||||
if block.Type == "tool_use" && block.partialJsonAccumulator != "" {
|
||||
var inputData map[string]interface{}
|
||||
err := json.Unmarshal([]byte(block.partialJsonAccumulator), &inputData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal accumulated JSON for tool use: %w", err)
|
||||
}
|
||||
block.Input = inputData
|
||||
}
|
||||
case "message_delta":
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("received message_delta before message_start")
|
||||
}
|
||||
delta, ok := streamEvent.Delta.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected delta type in message_delta: %T", streamEvent.Delta)
|
||||
}
|
||||
if stopReason, ok := delta["stop_reason"].(string); ok {
|
||||
finalMessage.StopReason = stopReason
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
// End of the stream
|
||||
goto END_STREAM
|
||||
|
||||
case "error":
|
||||
return nil, fmt.Errorf("received error event: %v", streamEvent.Message)
|
||||
|
||||
default:
|
||||
// Ignore unknown event types
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
END_STREAM:
|
||||
if finalMessage == nil {
|
||||
return nil, fmt.Errorf("no final message received")
|
||||
}
|
||||
|
||||
finalMessage.Content = make([]ContentBlock, len(contentBlocks))
|
||||
for _, v := range contentBlocks {
|
||||
finalMessage.Content = append(finalMessage.Content, *v)
|
||||
}
|
||||
|
||||
return convertResponseToMessage(*finalMessage)
|
||||
}
|
||||
|
||||
func convertResponseToMessage(resp ChatCompletionResponse) (*api.Message, error) {
|
||||
content := strings.Builder{}
|
||||
var toolCalls []api.ToolCall
|
||||
|
||||
for _, block := range resp.Content {
|
||||
switch block.Type {
|
||||
case "text":
|
||||
content.WriteString(block.Text)
|
||||
case "tool_use":
|
||||
parameters, ok := block.Input.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type for tool call parameters: %T", block.Input)
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
ID: block.ID,
|
||||
Name: block.Name,
|
||||
Parameters: parameters,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
message := &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
ToolCalls: toolCalls,
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
message.Role = api.MessageRoleToolCall
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
|
@ -0,0 +1,450 @@
|
|||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ContentPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
|
||||
FunctionResp *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]string `json:"args"`
|
||||
}
|
||||
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response interface{} `json:"response"`
|
||||
}
|
||||
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []ContentPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GenerationConfig struct {
|
||||
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
|
||||
Temperature *float32 `json:"temperature,omitempty"`
|
||||
TopP *float32 `json:"topP,omitempty"`
|
||||
TopK *int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
GenerationConfig *GenerationConfig `json:"generationConfig,omitempty"`
|
||||
}
|
||||
|
||||
type Candidate struct {
|
||||
Content Content `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
}
|
||||
|
||||
type GenerateContentResponse struct {
|
||||
Candidates []Candidate `json:"candidates"`
|
||||
UsageMetadata UsageMetadata `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
type FunctionDeclaration struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Values []string `json:"values,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
geminiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
// TODO: proper enum handing
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Values: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
geminiTools[i] = Tool{
|
||||
FunctionDeclarations: []FunctionDeclaration{
|
||||
{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "OBJECT",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return geminiTools
|
||||
}
|
||||
|
||||
func convertToolCallToGemini(toolCalls []api.ToolCall) []ContentPart {
|
||||
converted := make([]ContentPart, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
args := make(map[string]string)
|
||||
for k, v := range call.Parameters {
|
||||
args[k] = fmt.Sprintf("%v", v)
|
||||
}
|
||||
converted[i].FunctionCall = &FunctionCall{
|
||||
Name: call.Name,
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(functionCalls []FunctionCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(functionCalls))
|
||||
for i, call := range functionCalls {
|
||||
params := make(map[string]interface{})
|
||||
for k, v := range call.Args {
|
||||
params[k] = v
|
||||
}
|
||||
converted[i].Name = call.Name
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolResultsToGemini(toolResults []api.ToolResult) ([]FunctionResponse, error) {
|
||||
results := make([]FunctionResponse, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal([]byte(result.Result), &obj)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal %s: %v", result.Result, err)
|
||||
}
|
||||
results[i] = FunctionResponse{
|
||||
Name: result.ToolName,
|
||||
Response: obj,
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func createGenerateContentRequest(
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*GenerateContentRequest, error) {
|
||||
requestContents := make([]Content, 0, len(messages))
|
||||
|
||||
startIdx := 0
|
||||
var system string
|
||||
if len(messages) > 0 && messages[0].Role == api.MessageRoleSystem {
|
||||
system = messages[0].Content
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
for _, m := range messages[startIdx:] {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
content := Content{
|
||||
Role: "model",
|
||||
Parts: convertToolCallToGemini(m.ToolCalls),
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
case "tool_result":
|
||||
results, err := convertToolResultsToGemini(m.ToolResults)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// expand tool_result messages' results into multiple gemini messages
|
||||
for _, result := range results {
|
||||
content := Content{
|
||||
Role: "function",
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
FunctionResp: &result,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
default:
|
||||
var role string
|
||||
switch m.Role {
|
||||
case api.MessageRoleAssistant:
|
||||
role = "model"
|
||||
case api.MessageRoleUser:
|
||||
role = "user"
|
||||
}
|
||||
|
||||
if role == "" {
|
||||
panic("Unhandled role: " + m.Role)
|
||||
}
|
||||
|
||||
content := Content{
|
||||
Role: role,
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: m.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
requestContents = append(requestContents, content)
|
||||
}
|
||||
}
|
||||
|
||||
request := &GenerateContentRequest{
|
||||
Contents: requestContents,
|
||||
GenerationConfig: &GenerationConfig{
|
||||
MaxOutputTokens: ¶ms.MaxTokens,
|
||||
Temperature: ¶ms.Temperature,
|
||||
TopP: ¶ms.TopP,
|
||||
},
|
||||
}
|
||||
|
||||
if system != "" {
|
||||
request.SystemInstruction = &Content{
|
||||
Parts: []ContentPart{
|
||||
{
|
||||
Text: system,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:generateContent?key=%s",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp GenerateContentResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Candidates[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.Text != "" {
|
||||
content += part.Text
|
||||
}
|
||||
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req, err := createGenerateContentRequest(params, messages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(
|
||||
"%s/v1beta/models/%s:streamGenerateContent?key=%s&alt=sse",
|
||||
c.BaseURL, params.Model, c.APIKey,
|
||||
)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
var toolCalls []FunctionCall
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
|
||||
lastTokenCount := 0
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
|
||||
var resp GenerateContentResponse
|
||||
err = json.Unmarshal(line, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokens := resp.UsageMetadata.CandidatesTokenCount - lastTokenCount
|
||||
lastTokenCount += tokens
|
||||
|
||||
choice := resp.Candidates[0]
|
||||
for _, part := range choice.Content.Parts {
|
||||
if part.FunctionCall != nil {
|
||||
toolCalls = append(toolCalls, *part.FunctionCall)
|
||||
} else if part.Text != "" {
|
||||
output <- api.Chunk{
|
||||
Content: part.Text,
|
||||
TokenCount: uint(tokens),
|
||||
}
|
||||
content.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are function calls, handle them and recurse
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
package ollama
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
type OllamaClient struct {
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type OllamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type OllamaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []OllamaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type OllamaResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message OllamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
TotalDuration uint64 `json:"total_duration,omitempty"`
|
||||
LoadDuration uint64 `json:"load_duration,omitempty"`
|
||||
PromptEvalCount uint64 `json:"prompt_eval_count,omitempty"`
|
||||
PromptEvalDuration uint64 `json:"prompt_eval_duration,omitempty"`
|
||||
EvalCount uint64 `json:"eval_count,omitempty"`
|
||||
EvalDuration uint64 `json:"eval_duration,omitempty"`
|
||||
}
|
||||
|
||||
func createOllamaRequest(
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) OllamaRequest {
|
||||
requestMessages := make([]OllamaMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
message := OllamaMessage{
|
||||
Role: string(m.Role),
|
||||
Content: m.Content,
|
||||
}
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
|
||||
request := OllamaRequest{
|
||||
Model: params.Model,
|
||||
Messages: requestMessages,
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OllamaClient) sendRequest(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = false
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp OllamaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: completionResp.Message.Content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OllamaClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createOllamaRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/chat", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.sendRequest(httpReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
var streamResp OllamaResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(streamResp.Message.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
Content: streamResp.Message.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(streamResp.Message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
}
|
||||
|
||||
type ChatCompletionMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
Type string `json:"type"`
|
||||
ID string `json:"id"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type FunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolParameters `json:"parameters"`
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]ToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function FunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
Messages []ChatCompletionMessage `json:"messages"`
|
||||
N int `json:"n"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
ToolChoice string `json:"tool_choice,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionChoice struct {
|
||||
Message ChatCompletionMessage `json:"message"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
Choices []ChatCompletionChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamChoice struct {
|
||||
Delta ChatCompletionMessage `json:"delta"`
|
||||
}
|
||||
|
||||
type ChatCompletionStreamResponse struct {
|
||||
Choices []ChatCompletionStreamChoice `json:"choices"`
|
||||
}
|
||||
|
||||
func convertTools(tools []api.ToolSpec) []Tool {
|
||||
openaiTools := make([]Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
|
||||
params := make(map[string]ToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = ToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
openaiTools[i].Function = FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: ToolParameters{
|
||||
Type: "object",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []api.ToolCall) []ToolCall {
|
||||
converted := make([]ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Function.Name = call.Name
|
||||
|
||||
json, _ := json.Marshal(call.Parameters)
|
||||
converted[i].Function.Arguments = string(json)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []ToolCall) []api.ToolCall {
|
||||
converted := make([]api.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) ChatCompletionRequest {
|
||||
requestMessages := make([]ChatCompletionMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = "assistant"
|
||||
message.Content = m.Content
|
||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
requestMessages = append(requestMessages, message)
|
||||
case "tool_result":
|
||||
// expand tool_result messages' results into multiple openAI messages
|
||||
for _, result := range m.ToolResults {
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = "tool"
|
||||
message.Content = result.Result
|
||||
message.ToolCallID = result.ToolCallID
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
default:
|
||||
message := ChatCompletionMessage{}
|
||||
message.Role = string(m.Role)
|
||||
message.Content = m.Content
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
request := ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Messages: requestMessages,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
}
|
||||
|
||||
if len(params.Toolbox) > 0 {
|
||||
request.Tools = convertTools(params.Toolbox)
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendRequest(ctx context.Context, r ChatCompletionRequest) (*http.Response, error) {
|
||||
jsonData, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
bytes, _ := io.ReadAll(resp.Body)
|
||||
return resp, fmt.Errorf("%v", string(bytes))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var completionResp ChatCompletionResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&completionResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
choice := completionResp.Choices[0]
|
||||
|
||||
var content string
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content = lastMessage.Content + choice.Message.Content
|
||||
} else {
|
||||
content = choice.Message.Content
|
||||
}
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params api.RequestParameters,
|
||||
messages []api.Message,
|
||||
output chan<- api.Chunk,
|
||||
) (*api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, fmt.Errorf("Can't create completion from no messages")
|
||||
}
|
||||
|
||||
req := createChatCompletionRequest(params, messages)
|
||||
req.Stream = true
|
||||
|
||||
resp, err := c.sendRequest(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []ToolCall{}
|
||||
|
||||
lastMessage := messages[len(messages)-1]
|
||||
if lastMessage.Role.IsAssistant() {
|
||||
content.WriteString(lastMessage.Content)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data: ")) {
|
||||
continue
|
||||
}
|
||||
|
||||
line = bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(line, []byte("[DONE]")) {
|
||||
break
|
||||
}
|
||||
|
||||
var streamResp ChatCompletionStreamResponse
|
||||
err = json.Unmarshal(line, &streamResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
delta := streamResp.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return nil, fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(delta.Content) > 0 {
|
||||
output <- api.Chunk{
|
||||
Content: delta.Content,
|
||||
TokenCount: 1,
|
||||
}
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleToolCall,
|
||||
Content: content.String(),
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
}, nil
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type ToolSpec struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*ToolSpec, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id" yaml:"-"`
|
||||
Name string `json:"name" yaml:"tool"`
|
||||
Parameters map[string]interface{} `json:"parameters" yaml:"parameters"`
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID" yaml:"-"`
|
||||
ToolName string `json:"toolName,omitempty" yaml:"tool"`
|
||||
Result string `json:"result,omitempty" yaml:"result"`
|
||||
}
|
||||
|
||||
type ToolCalls []ToolCall
|
||||
|
||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tc)
|
||||
return
|
||||
}
|
||||
|
||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||
if len(tc) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(tc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (tr *ToolResults) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (tr ToolResults) Value() (driver.Value, error) {
|
||||
if len(tr) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
|
@ -1,32 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var config *Config
|
||||
var store *Store
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
|
||||
config, err = NewConfig()
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
store, err = NewStore()
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
}
|
550
pkg/cli/cmd.go
550
pkg/cli/cmd.go
|
@ -1,550 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
maxTokens int
|
||||
model string
|
||||
systemPrompt string
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
func init() {
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
|
||||
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
|
||||
cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "The system prompt to use.")
|
||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file whose contents are used as the system prompt.")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(
|
||||
continueCmd,
|
||||
lsCmd,
|
||||
newCmd,
|
||||
promptCmd,
|
||||
replyCmd,
|
||||
retryCmd,
|
||||
rmCmd,
|
||||
viewCmd,
|
||||
)
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func SystemPrompt() string {
|
||||
if systemPromptFile != "" {
|
||||
content, err := FileContents(systemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v", systemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
|
||||
// LLMRequest prompts the LLM with the given Message, writes the (partial)
|
||||
// response to stdout, and returns the (partial) response or any errors.
|
||||
func LLMRequest(messages []Message) (string, error) {
|
||||
// receiver receives the reponse from LLM
|
||||
receiver := make(chan string)
|
||||
defer close(receiver)
|
||||
|
||||
// start HandleDelayedContent goroutine to print received data to stdout
|
||||
go HandleDelayedContent(receiver)
|
||||
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
if response != "" {
|
||||
if err != nil {
|
||||
Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func InputFromArgsOrEditor(args []string, placeholder string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = InputFromEditor(placeholder, "message.*.md")
|
||||
if err != nil {
|
||||
Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Trim(strings.Join(args, " "), " \t\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "lmcli",
|
||||
Short: "Interact with Large Language Models",
|
||||
Long: `lmcli is a CLI tool to interact with Large Language Models.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// execute `lm ls` by default
|
||||
},
|
||||
}
|
||||
|
||||
var lsCmd = &cobra.Command{
|
||||
Use: "ls",
|
||||
Short: "List existing conversations",
|
||||
Long: `List all existing conversations in descending order of recent activity.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
conversations, err := store.Conversations()
|
||||
if err != nil {
|
||||
fmt.Println("Could not fetch conversations.")
|
||||
return
|
||||
}
|
||||
|
||||
// Example output
|
||||
// $ lmcli ls
|
||||
// last hour:
|
||||
// 98sg - 12 minutes ago - Project discussion
|
||||
// last day:
|
||||
// tj3l - 10 hours ago - Deep learning concepts
|
||||
// last week:
|
||||
// bwfm - 2 days ago - Machine learning study
|
||||
// 8n3h - 3 days ago - Weekend plans
|
||||
// f3n7 - 6 days ago - CLI development
|
||||
// last month:
|
||||
// 5hn2 - 8 days ago - Book club discussion
|
||||
// b7ze - 20 days ago - Gardening tips and tricks
|
||||
// last 6 months:
|
||||
// 3jn2 - 30 days ago - Web development best practices
|
||||
// 43jk - 2 months ago - Longboard maintenance
|
||||
// g8d9 - 3 months ago - History book club
|
||||
// 4lk3 - 4 months ago - Local events and meetups
|
||||
// 43jn - 6 months ago - Mobile photography techniques
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
categories := []string{
|
||||
"recent",
|
||||
"last hour",
|
||||
"last 6 hours",
|
||||
"last day",
|
||||
"last week",
|
||||
"last month",
|
||||
"last 6 months",
|
||||
"older",
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
switch {
|
||||
case messageAge <= 10*time.Minute:
|
||||
category = "recent"
|
||||
case messageAge <= time.Hour:
|
||||
category = "last hour"
|
||||
case messageAge <= 6*time.Hour:
|
||||
category = "last 6 hours"
|
||||
case messageAge <= 24*time.Hour:
|
||||
category = "last day"
|
||||
case messageAge <= 7*24*time.Hour:
|
||||
category = "last week"
|
||||
case messageAge <= 30*24*time.Hour:
|
||||
category = "last month"
|
||||
case messageAge <= 6*30*24*time.Hour: // Approximate as 6 months
|
||||
category = "last 6 months"
|
||||
default:
|
||||
category = "older"
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
conversation.ShortName.String,
|
||||
humanTimeElapsedSince(messageAge),
|
||||
conversation.Title,
|
||||
)
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
for _, category := range categories {
|
||||
conversations, ok := categorized[category]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(conversations, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
fmt.Printf("%s:\n", category)
|
||||
for _, conv := range conversations {
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var rmCmd = &cobra.Command{
|
||||
Use: "rm <conversation>",
|
||||
Short: "Remove a conversation",
|
||||
Long: `Removes a conversation by its short name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
Fatal("Could not search for conversation: %v\n", err)
|
||||
}
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
err = store.DeleteConversation(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not delete conversation: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var viewCmd = &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
RenderConversation(messages, false)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var replyCmd = &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Send a reply to a conversation",
|
||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
messageContents := InputFromArgsOrEditor(args[1:], "# How would you like to reply?\n")
|
||||
if messageContents == "" {
|
||||
Fatal("No reply was provided.\n")
|
||||
}
|
||||
|
||||
userReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
}
|
||||
|
||||
err = store.SaveMessage(&userReply)
|
||||
if err != nil {
|
||||
Warn("Could not save your reply: %v\n", err)
|
||||
}
|
||||
|
||||
messages = append(messages, userReply)
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var newCmd = &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
messageContents := InputFromArgsOrEditor(args, "# What would you like to say?\n")
|
||||
if messageContents == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
conversation := Conversation{}
|
||||
err := store.SaveConversation(&conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not save new conversation: %v\n", err)
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "system",
|
||||
OriginalContent: SystemPrompt(),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
},
|
||||
}
|
||||
for _, message := range messages {
|
||||
err = store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
reply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
reply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
reply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Fatal("Could not save reply: %v\n", err)
|
||||
}
|
||||
|
||||
err = conversation.GenerateTitle()
|
||||
if err != nil {
|
||||
Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
err = store.SaveConversation(&conversation)
|
||||
if err != nil {
|
||||
Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var promptCmd = &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
message := InputFromArgsOrEditor(args, "# What would you like to say?\n")
|
||||
if message == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
OriginalContent: SystemPrompt(),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var retryCmd = &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retries the last conversation prompt.",
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
var lastUserMessageIndex int
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
lastUserMessageIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
messages = messages[:lastUserMessageIndex+1]
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var continueCmd = &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continues where the previous prompt left off.",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-yaml/yaml"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ModelDefaults *struct {
|
||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||
} `yaml:"modelDefaults"`
|
||||
OpenAI *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
||||
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
||||
} `yaml:"openai"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
}
|
||||
|
||||
func ConfigDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func NewConfig() (*Config, error) {
|
||||
configFile := filepath.Join(ConfigDir(), "config.yaml")
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
configBytes, err := os.ReadFile(configFile)
|
||||
if os.IsNotExist(err) {
|
||||
shouldWriteDefaults = true
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||
} else {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
}
|
||||
|
||||
shouldWriteDefaults = SetStructDefaults(c)
|
||||
if shouldWriteDefaults {
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
}
|
||||
bytes, _ := yaml.Marshal(c)
|
||||
_, err = file.Write(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m *Message) FriendlyRole() string {
|
||||
var friendlyRole string
|
||||
switch m.Role {
|
||||
case "user":
|
||||
friendlyRole = "You"
|
||||
case "system":
|
||||
friendlyRole = "System"
|
||||
case "assistant":
|
||||
friendlyRole = "Assistant"
|
||||
default:
|
||||
friendlyRole = m.Role
|
||||
}
|
||||
return friendlyRole
|
||||
}
|
||||
|
||||
func (c *Conversation) GenerateTitle() error {
|
||||
const header = "Generate a consise 4-5 word title for the conversation below."
|
||||
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting())
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
model := "gpt-3.5-turbo" // use cheap model to generate title
|
||||
response, err := CreateChatCompletion(model, messages, 25)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Title = response
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conversation) FormatForExternalPrompting() string {
|
||||
sb := strings.Builder{}
|
||||
messages, err := store.Messages(c)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation %v", c)
|
||||
}
|
||||
for _, message := range messages {
|
||||
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
|
||||
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
|
@ -1,582 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type FunctionResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameter struct {
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameters struct {
|
||||
Type string `json:"type"` // "object"
|
||||
Properties map[string]FunctionParameter `json:"properties"`
|
||||
Required []string `json:"required,omitempty"` // required function parameter names
|
||||
}
|
||||
|
||||
type AvailableTool struct {
|
||||
openai.Tool
|
||||
// The tool's implementation. Returns a string, as tool call results
|
||||
// are treated as normal messages with string contents.
|
||||
Impl func(arguments map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Results are returned as JSON in the following format:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
// result may be an empty array if there are no files in the directory
|
||||
"result": [
|
||||
{"name": "a_file", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
... // more files or directories
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size (in bytes) of the file.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the file is prefixed with its line number and a tabs (\t) to make
|
||||
it make it easier to see which lines to change for other modifications.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Note: only use this tool when you've been explicitly asked to create or write to a file.
|
||||
|
||||
When using this function, you do not need to share the content you intend to write with the user first.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
}`
|
||||
|
||||
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
|
||||
)
|
||||
|
||||
var AvailableTools = map[string]AvailableTool{
|
||||
"read_dir": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"relative_dir": {
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return ReadDir(relativeDir), nil
|
||||
},
|
||||
},
|
||||
"read_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
return ReadFile(path), nil
|
||||
},
|
||||
},
|
||||
"write_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
return WriteFile(path, content), nil
|
||||
},
|
||||
},
|
||||
"file_insert_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"position": {
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "position", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return FileInsertLines(path, position, content), nil
|
||||
},
|
||||
},
|
||||
"file_replace_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"start_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
},
|
||||
"end_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "start_line"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
return FileReplaceLines(path, start_line, end_line, content), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func resultToJson(result FunctionResult) string {
|
||||
if result.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
result.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
||||
// returns their results formatted as []Message(s) with role: 'tool' and.
|
||||
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
||||
var toolResults []Message
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Type != "function" {
|
||||
// unsupported tool type
|
||||
continue
|
||||
}
|
||||
|
||||
tool, ok := AvailableTools[toolCall.Function.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
||||
}
|
||||
|
||||
var functionArgs map[string]interface{}
|
||||
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
||||
}
|
||||
|
||||
// TODO: ability to silence this
|
||||
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := tool.Impl(functionArgs)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, Message{
|
||||
Role: "tool",
|
||||
OriginalContent: toolResult,
|
||||
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
||||
// name is not required since the introduction of ToolCallID
|
||||
// hypothesis: by setting it, we inform the model of what a
|
||||
// function's purpose was if future requests omit the function
|
||||
// definition
|
||||
})
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func isPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
||||
}
|
||||
if ok, err := isPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
||||
}
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ReadDir(path string) string {
|
||||
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: dirContents})
|
||||
}
|
||||
|
||||
func ReadFile(path string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{
|
||||
Result: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteFile(path string, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
return resultToJson(FunctionResult{})
|
||||
}
|
||||
|
||||
func FileInsertLines(path string, position int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
}
|
||||
|
||||
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
|
||||
}
|
|
@ -1,162 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||
for _, m := range messages {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Content: m.OriginalContent,
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
message.ToolCallID = m.ToolCallID.String
|
||||
}
|
||||
if m.ToolCalls.Valid {
|
||||
// unmarshal directly into chatMessage.ToolCalls
|
||||
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
||||
if err != nil {
|
||||
// TODO: handle, this shouldn't really happen since
|
||||
// we only save the successfully marshal'd data to database
|
||||
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
||||
}
|
||||
}
|
||||
chatCompletionMessages = append(chatCompletionMessages, message)
|
||||
}
|
||||
|
||||
var tools []openai.Tool
|
||||
for _, t := range AvailableTools {
|
||||
// TODO: support some way to limit which tools are available per-request
|
||||
tools = append(tools, t.Tool)
|
||||
}
|
||||
|
||||
return openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: chatCompletionMessages,
|
||||
MaxTokens: maxTokens,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
|
||||
}
|
||||
}
|
||||
|
||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||
// response. CreateChatCompletion will recursively call itself in the case of
|
||||
// tool calls, until a response is received with the final user-facing output.
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
if choice.Message.Content != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies added
|
||||
// to the original messages
|
||||
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
// and both returns and streams the response to the provided output channel.
|
||||
// May return a partial response if an error occurs mid-stream.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
if content.String() != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(toolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
|
133
pkg/cli/store.go
133
pkg/cli/store.go
|
@ -1,133 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role string // one of: 'user', 'assistant', 'tool'
|
||||
CreatedAt time.Time
|
||||
ToolCallID sql.NullString
|
||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
}
|
||||
|
||||
func DataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func NewStore() (*Store, error) {
|
||||
databaseFile := filepath.Join(DataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
|
||||
models := []any{
|
||||
&Conversation{},
|
||||
&Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &Store{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *Store) SaveConversation(conversation *Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteConversation(conversation *Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *Store) SaveMessage(message *Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *Store) Conversations() ([]Conversation, error) {
|
||||
var conversations []Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *Store) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) {
|
||||
var conversation Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *Store) Messages(conversation *Conversation) ([]Message, error) {
|
||||
var messages []Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *Store) LastMessage(conversation *Conversation) (*Message, error) {
|
||||
var message Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
113
pkg/cli/tty.go
113
pkg/cli/tty.go
|
@ -1,113 +0,0 @@
|
|||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/alecthomas/chroma/v2/quick"
|
||||
"github.com/gookit/color"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// noftify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
fmt.Print("\r")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func HandleDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(messages []Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
message.RenderTTY()
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
||||
// and writes it to stdout.
|
||||
func HighlightMarkdown(markdownText string) error {
|
||||
return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
}
|
||||
|
||||
func (m *Message) RenderTTY() {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
var roleStyle color.Style
|
||||
switch m.Role {
|
||||
case "system":
|
||||
roleStyle = color.Style{color.HiRed}
|
||||
case "user":
|
||||
roleStyle = color.Style{color.HiGreen}
|
||||
case "assistant":
|
||||
roleStyle = color.Style{color.HiBlue}
|
||||
default:
|
||||
roleStyle = color.Style{color.FgWhite}
|
||||
}
|
||||
roleStyle.Add(color.Bold)
|
||||
|
||||
headerColor := color.FgYellow
|
||||
separator := headerColor.Sprint("===")
|
||||
timestamp := headerColor.Sprint(messageAge)
|
||||
role := roleStyle.Sprint(m.FriendlyRole())
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.OriginalContent != "" {
|
||||
HighlightMarkdown(m.OriginalContent)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,48 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "chat [conversation]",
|
||||
Short: "Open the chat interface",
|
||||
Long: `Open the chat interface, optionally on a given conversation.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
shortname := ""
|
||||
if len(args) == 1 {
|
||||
shortname = args[0]
|
||||
}
|
||||
if shortname != ""{
|
||||
_, err := cmdutil.LookupConversationE(ctx, shortname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = tui.Launch(ctx, shortname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "clone <conversation>",
|
||||
Short: "Clone conversations",
|
||||
Long: `Clones the provided conversation.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
toClone, err := cmdutil.LookupConversationE(ctx, shortName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clone, messageCnt, err := ctx.Store.CloneConversation(*toClone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to clone conversation: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s - %s\n", messageCnt, clone.ShortName.String, clone.Title)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
var root = &cobra.Command{
|
||||
Use: "lmcli <command> [flags]",
|
||||
Long: `lmcli - Large Language Model CLI`,
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Usage()
|
||||
},
|
||||
}
|
||||
|
||||
root.AddCommand(
|
||||
ChatCmd(ctx),
|
||||
ContinueCmd(ctx),
|
||||
CloneCmd(ctx),
|
||||
EditCmd(ctx),
|
||||
ListCmd(ctx),
|
||||
NewCmd(ctx),
|
||||
PromptCmd(ctx),
|
||||
RenameCmd(ctx),
|
||||
ReplyCmd(ctx),
|
||||
RetryCmd(ctx),
|
||||
RemoveCmd(ctx),
|
||||
ViewCmd(ctx),
|
||||
)
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func applyGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) {
|
||||
f := cmd.Flags()
|
||||
|
||||
// -m, --model
|
||||
f.StringVarP(
|
||||
ctx.Config.Defaults.Model, "model", "m",
|
||||
*ctx.Config.Defaults.Model, "Which model to generate a response with",
|
||||
)
|
||||
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
||||
})
|
||||
|
||||
// -a, --agent
|
||||
f.StringVarP(&ctx.Config.Defaults.Agent, "agent", "a", ctx.Config.Defaults.Agent, "Which agent to interact with")
|
||||
cmd.RegisterFlagCompletionFunc("agent", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||
return ctx.GetAgents(), cobra.ShellCompDirectiveDefault
|
||||
})
|
||||
|
||||
// --max-length
|
||||
f.IntVar(ctx.Config.Defaults.MaxTokens, "max-length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||
// --temperature
|
||||
f.Float32VarP(ctx.Config.Defaults.Temperature, "temperature", "t", *ctx.Config.Defaults.Temperature, "Sampling temperature")
|
||||
|
||||
// --system-prompt
|
||||
f.StringVar(&ctx.Config.Defaults.SystemPrompt, "system-prompt", ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||
// --system-prompt-file
|
||||
f.StringVar(&ctx.Config.Defaults.SystemPromptFile, "system-prompt-file", ctx.Config.Defaults.SystemPromptFile, "A path to a file containing the system prompt")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
func validateGenerationFlags(ctx *lmcli.Context, cmd *cobra.Command) error {
|
||||
f := cmd.Flags()
|
||||
model, err := f.GetString("model")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing --model: %w", err)
|
||||
}
|
||||
if model != "" && !slices.Contains(ctx.GetModels(), model) {
|
||||
return fmt.Errorf("Unknown model: %s", model)
|
||||
}
|
||||
|
||||
agent, err := f.GetString("agent")
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error parsing --agent: %w", err)
|
||||
}
|
||||
if agent != "" && !slices.Contains(ctx.GetAgents(), agent) {
|
||||
return fmt.Errorf("Unknown agent: %s", agent)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Join(args, " ")
|
||||
}
|
||||
message = strings.Trim(message, " \t\n")
|
||||
return
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continue a conversation from the last message",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) < 2 {
|
||||
return fmt.Errorf("conversation expected to have at least 2 messages")
|
||||
}
|
||||
|
||||
lastMessage := &messages[len(messages)-1]
|
||||
if lastMessage.Role != api.MessageRoleAssistant {
|
||||
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||
}
|
||||
|
||||
// Output the contents of the last message so far
|
||||
fmt.Print(lastMessage.Content)
|
||||
|
||||
// Submit the LLM request, allowing it to continue the last message
|
||||
continuedOutput, err := cmdutil.Prompt(ctx, messages, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||
}
|
||||
|
||||
// Append the new response to the original message
|
||||
lastMessage.Content += strings.TrimRight(continuedOutput.Content, "\n\t ")
|
||||
|
||||
// Update the original message
|
||||
err = ctx.Store.UpdateMessage(lastMessage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not update the last message: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "edit <conversation>",
|
||||
Short: "Edit the last user reply in a conversation",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
offset, _ := cmd.Flags().GetInt("offset")
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
if offset > len(messages)-1 {
|
||||
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||
}
|
||||
|
||||
desiredIdx := len(messages) - 1 - offset
|
||||
toEdit := messages[desiredIdx]
|
||||
|
||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||
switch newContents {
|
||||
case toEdit.Content:
|
||||
return fmt.Errorf("No edits were made.")
|
||||
case "":
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
toEdit.Content = newContents
|
||||
|
||||
role, _ := cmd.Flags().GetString("role")
|
||||
if role != "" {
|
||||
if role != string(api.MessageRoleUser) && role != string(api.MessageRoleAssistant) {
|
||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||
}
|
||||
toEdit.Role = api.MessageRole(role)
|
||||
}
|
||||
|
||||
// Update the message in-place
|
||||
inplace, _ := cmd.Flags().GetBool("in-place")
|
||||
if inplace {
|
||||
return ctx.Store.UpdateMessage(&toEdit)
|
||||
}
|
||||
|
||||
// Otherwise, create a branch for the edited message
|
||||
message, _, err := ctx.Store.CloneBranch(toEdit)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if desiredIdx > 0 {
|
||||
// update selected reply
|
||||
messages[desiredIdx-1].SelectedReply = message
|
||||
err = ctx.Store.UpdateMessage(&messages[desiredIdx-1])
|
||||
} else {
|
||||
// update selected root
|
||||
conversation.SelectedRoot = message
|
||||
err = ctx.Store.UpdateConversation(conversation)
|
||||
}
|
||||
return err
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolP("in-place", "i", true, "Edit the message in-place, rather than creating a branch")
|
||||
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||
cmd.Flags().StringP("role", "r", "", "Change the role of the edited message (user or assistant)")
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
LS_COUNT int = 5
|
||||
)
|
||||
|
||||
func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List conversations",
|
||||
Long: `List conversations in order of recent activity`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
messages, err := ctx.Store.LatestConversationMessages()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
name string
|
||||
cutoff time.Duration
|
||||
}
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
dayOfWeek := int(now.Weekday())
|
||||
categories := []Category{
|
||||
{"today", now.Sub(midnight)},
|
||||
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||
{"this month", now.Sub(monthStart)},
|
||||
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||
{"older", now.Sub(time.Time{})},
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
all, _ := cmd.Flags().GetBool("all")
|
||||
|
||||
for _, message := range messages {
|
||||
messageAge := now.Sub(message.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, c := range categories {
|
||||
if messageAge < c.cutoff {
|
||||
category = c.name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
message.Conversation.ShortName.String,
|
||||
util.HumanTimeElapsedSince(messageAge),
|
||||
message.Conversation.Title,
|
||||
)
|
||||
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
count, _ := cmd.Flags().GetInt("count")
|
||||
var conversationsPrinted int
|
||||
outer:
|
||||
for _, category := range categories {
|
||||
conversationLines, ok := categorized[category.name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("%s:\n", category.name)
|
||||
for _, conv := range conversationLines {
|
||||
if conversationsPrinted >= count && !all {
|
||||
fmt.Printf("%d remaining conversation(s), use --all to view.\n", len(messages)-conversationsPrinted)
|
||||
break outer
|
||||
}
|
||||
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
conversationsPrinted++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolP("all", "a", false, "Show all conversations")
|
||||
cmd.Flags().IntP("count", "c", LS_COUNT, "How many conversations to show")
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,56 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
input := inputFromArgsOrEditor(args, "# Start a new conversation below\n", "")
|
||||
if input == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
messages := []api.Message{{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: input,
|
||||
}}
|
||||
|
||||
conversation, messages, err := ctx.Store.StartConversation(messages...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not start a new conversation: %v", err)
|
||||
}
|
||||
|
||||
cmdutil.HandleReply(ctx, &messages[len(messages)-1], true)
|
||||
|
||||
title, err := cmdutil.GenerateTitle(ctx, messages)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not generate title for conversation %s: %v\n", conversation.ShortName.String, err)
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
input := inputFromArgsOrEditor(args, "# Write your prompt below\n", "")
|
||||
if input == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
messages := []api.Message{{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: input,
|
||||
}}
|
||||
|
||||
_, err = cmdutil.Prompt(ctx, messages, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rm <conversation>...",
|
||||
Short: "Remove conversations",
|
||||
Long: `Remove conversations by their short names.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var toRemove []*api.Conversation
|
||||
for _, shortName := range args {
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
toRemove = append(toRemove, conversation)
|
||||
}
|
||||
var errors []error
|
||||
for _, c := range toRemove {
|
||||
err := ctx.Store.DeleteConversation(c)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||
}
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Could not remove some conversations: %v", errors)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
var completions []string
|
||||
outer:
|
||||
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
||||
parts := strings.Split(completion, "\t")
|
||||
for _, arg := range args {
|
||||
if parts[0] == arg {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
completions = append(completions, completion)
|
||||
}
|
||||
return completions, compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
)
|
||||
|
||||
func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rename <conversation> [title]",
|
||||
Short: "Rename a conversation",
|
||||
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
var err error
|
||||
var title string
|
||||
|
||||
generate, _ := cmd.Flags().GetBool("generate")
|
||||
if generate {
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
title, err = cmdutil.GenerateTitle(ctx, messages)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||
}
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Conversation title not provided.")
|
||||
}
|
||||
title = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not update conversation title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("generate", false, "Generate a conversation title")
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Reply to a conversation",
|
||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||
if reply == "" {
|
||||
return fmt.Errorf("No reply was provided.")
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, api.Message{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: reply,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retry the last user reply in a conversation",
|
||||
Long: `Prompt the conversation from the last user response.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
err := validateGenerationFlags(ctx, cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
// Load the complete thread from the root message
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
offset, _ := cmd.Flags().GetInt("offset")
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
if offset > len(messages)-1 {
|
||||
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||
}
|
||||
|
||||
retryFromIdx := len(messages) - 1 - offset
|
||||
|
||||
// decrease retryFromIdx until we hit a user message
|
||||
for retryFromIdx >= 0 && messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||
retryFromIdx--
|
||||
}
|
||||
|
||||
if messages[retryFromIdx].Role != api.MessageRoleUser {
|
||||
return fmt.Errorf("No user messages to retry")
|
||||
}
|
||||
|
||||
fmt.Printf("Idx: %d Message: %v\n", retryFromIdx, messages[retryFromIdx])
|
||||
|
||||
// Start a new branch at the last user message
|
||||
cmdutil.HandleReply(ctx, &messages[retryFromIdx], true)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Int("offset", 0, "Offset from the last message to retry from.")
|
||||
|
||||
applyGenerationFlags(ctx, cmd)
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,337 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Prompt prompts the configured the configured model and streams the response
|
||||
// to stdout. Returns all model reply messages.
|
||||
func Prompt(ctx *lmcli.Context, messages []api.Message, callback func(api.Message)) (*api.Message, error) {
|
||||
m, provider, err := ctx.GetModelProvider(*ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params := api.RequestParameters{
|
||||
Model: m,
|
||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *ctx.Config.Defaults.Temperature,
|
||||
}
|
||||
|
||||
system := ctx.DefaultSystemPrompt()
|
||||
|
||||
agent := ctx.GetAgent(ctx.Config.Defaults.Agent)
|
||||
if agent != nil {
|
||||
if agent.SystemPrompt != "" {
|
||||
system = agent.SystemPrompt
|
||||
}
|
||||
params.Toolbox = agent.Toolbox
|
||||
}
|
||||
|
||||
if system != "" {
|
||||
messages = api.ApplySystemPrompt(messages, system, false)
|
||||
}
|
||||
|
||||
content := make(chan api.Chunk)
|
||||
defer close(content)
|
||||
|
||||
// render the content received over the channel
|
||||
go ShowDelayedContent(content)
|
||||
|
||||
reply, err := provider.CreateChatCompletionStream(
|
||||
context.Background(), params, messages, content,
|
||||
)
|
||||
|
||||
if reply.Content != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
|
||||
if err != nil {
|
||||
lmcli.Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return reply, err
|
||||
}
|
||||
|
||||
// lookupConversation either returns the conversation found by the
|
||||
// short name or exits the program
|
||||
func LookupConversation(ctx *lmcli.Context, shortName string) *api.Conversation {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
lmcli.Fatal("Conversation not found: %s\n", shortName)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*api.Conversation, error) {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return nil, fmt.Errorf("Conversation not found: %s", shortName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func HandleConversationReply(ctx *lmcli.Context, c *api.Conversation, persist bool, toSend ...api.Message) {
|
||||
messages, err := ctx.Store.PathToLeaf(c.SelectedRoot)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||
}
|
||||
HandleReply(ctx, &messages[len(messages)-1], persist, toSend...)
|
||||
}
|
||||
|
||||
// handleConversationReply handles sending messages to an existing
|
||||
// conversation, optionally persisting both the sent replies and responses.
|
||||
func HandleReply(ctx *lmcli.Context, to *api.Message, persist bool, messages ...api.Message) {
|
||||
if to == nil {
|
||||
lmcli.Fatal("Can't prompt from an empty message.")
|
||||
}
|
||||
|
||||
existing, err := ctx.Store.PathToRoot(to)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not load messages: %v\n", err)
|
||||
}
|
||||
|
||||
RenderConversation(ctx, append(existing, messages...), true)
|
||||
|
||||
var savedReplies []api.Message
|
||||
if persist && len(messages) > 0 {
|
||||
savedReplies, err = ctx.Store.Reply(to, messages...)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save messages: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// render a message header with no contents
|
||||
RenderMessage(ctx, (&api.Message{Role: api.MessageRoleAssistant}))
|
||||
|
||||
var lastSavedMessage *api.Message
|
||||
lastSavedMessage = to
|
||||
if len(savedReplies) > 0 {
|
||||
lastSavedMessage = &savedReplies[len(savedReplies)-1]
|
||||
}
|
||||
|
||||
replyCallback := func(reply api.Message) {
|
||||
if !persist {
|
||||
return
|
||||
}
|
||||
savedReplies, err = ctx.Store.Reply(lastSavedMessage, reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
lastSavedMessage = &savedReplies[0]
|
||||
}
|
||||
|
||||
_, err = Prompt(ctx, append(existing, messages...), replyCallback)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func FormatForExternalPrompt(messages []api.Message, system bool) string {
|
||||
sb := strings.Builder{}
|
||||
for _, message := range messages {
|
||||
if message.Content == "" {
|
||||
continue
|
||||
}
|
||||
switch message.Role {
|
||||
case api.MessageRoleAssistant, api.MessageRoleToolCall:
|
||||
sb.WriteString("Assistant:\n\n")
|
||||
case api.MessageRoleUser:
|
||||
sb.WriteString("User:\n\n")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func GenerateTitle(ctx *lmcli.Context, messages []api.Message) (string, error) {
|
||||
const systemPrompt = `You will be shown a conversation between a user and an AI assistant. Your task is to generate a short title (8 words or less) for the provided conversation that reflects the conversation's topic. Your response is expected to be in JSON in the format shown below.
|
||||
|
||||
Example conversation:
|
||||
|
||||
[{"role": "user", "content": "Can you help me with my math homework?"},{"role": "assistant", "content": "Sure, what topic are you struggling with?"}]
|
||||
|
||||
Example response:
|
||||
|
||||
{"title": "Help with math homework"}
|
||||
`
|
||||
type msg struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
var msgs []msg
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case api.MessageRoleAssistant, api.MessageRoleUser:
|
||||
msgs = append(msgs, msg{string(m.Role), m.Content})
|
||||
}
|
||||
}
|
||||
|
||||
// Serialize the conversation to JSON
|
||||
conversation, err := json.Marshal(msgs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
generateRequest := []api.Message{
|
||||
{
|
||||
Role: api.MessageRoleSystem,
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: string(conversation),
|
||||
},
|
||||
}
|
||||
|
||||
m, provider, err := ctx.GetModelProvider(
|
||||
*ctx.Config.Conversations.TitleGenerationModel,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestParams := api.RequestParameters{
|
||||
Model: m,
|
||||
MaxTokens: 25,
|
||||
}
|
||||
|
||||
response, err := provider.CreateChatCompletion(
|
||||
context.Background(), requestParams, generateRequest,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
var jsonResponse struct {
|
||||
Title string `json:"title"`
|
||||
}
|
||||
err = json.Unmarshal([]byte(response.Content), &jsonResponse)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return jsonResponse.Title, nil
|
||||
}
|
||||
|
||||
// ShowWaitAnimation prints an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// notify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
// Save the current cursor position
|
||||
fmt.Print("\033[s")
|
||||
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
// Relmcli the cursor position
|
||||
fmt.Print("\033[u")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
// Move the cursor to the saved position
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\033[u")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func ShowDelayedContent(content <-chan api.Chunk) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(ctx *lmcli.Context, messages []api.Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
RenderMessage(ctx, &message)
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RenderMessage(ctx *lmcli.Context, m *api.Message) {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||
|
||||
switch m.Role {
|
||||
case api.MessageRoleSystem:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||
case api.MessageRoleUser:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||
case api.MessageRoleAssistant:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||
}
|
||||
|
||||
role := headerStyle.Render(m.Role.FriendlyRole())
|
||||
|
||||
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
||||
separator := separatorStyle.Render("===")
|
||||
timestamp := separatorStyle.Render(messageAge)
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.Content != "" {
|
||||
ctx.Chroma.Highlight(os.Stdout, m.Content)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
|
@ -0,0 +1,45 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.PathToLeaf(conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation %s: %v", conversation.ShortName.String, err)
|
||||
}
|
||||
|
||||
cmdutil.RenderConversation(ctx, messages, false)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
package lmcli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Defaults *struct {
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||
Temperature *float32 `yaml:"temperature" default:"0.2"`
|
||||
SystemPrompt string `yaml:"systemPrompt,omitempty"`
|
||||
SystemPromptFile string `yaml:"systemPromptFile,omitempty"`
|
||||
// CLI only
|
||||
Agent string `yaml:"-"`
|
||||
} `yaml:"defaults"`
|
||||
Conversations *struct {
|
||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||
} `yaml:"conversations"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
Agents []*struct {
|
||||
Name string `yaml:"name"`
|
||||
SystemPrompt string `yaml:"systemPrompt"`
|
||||
Tools []string `yaml:"tools"`
|
||||
} `yaml:"agents"`
|
||||
Providers []*struct {
|
||||
Name string `yaml:"name,omitempty"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl,omitempty"`
|
||||
APIKey string `yaml:"apiKey,omitempty"`
|
||||
Models []string `yaml:"models"`
|
||||
} `yaml:"providers"`
|
||||
}
|
||||
|
||||
func NewConfig(configFile string) (*Config, error) {
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
configExists := true
|
||||
configBytes, err := os.ReadFile(configFile)
|
||||
if os.IsNotExist(err) {
|
||||
configExists = false
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||
} else {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
}
|
||||
|
||||
shouldWriteDefaults = util.SetStructDefaults(c)
|
||||
if !configExists || shouldWriteDefaults {
|
||||
if configExists {
|
||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
||||
os.Rename(configFile, configFile+".bak")
|
||||
}
|
||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
}
|
||||
encoder := yaml.NewEncoder(file)
|
||||
encoder.SetIndent(2)
|
||||
err = encoder.Encode(c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
|
@ -0,0 +1,229 @@
|
|||
package lmcli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/google"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/ollama"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Agent struct {
|
||||
Name string
|
||||
SystemPrompt string
|
||||
Toolbox []api.ToolSpec
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
// high level app configuration, may be mutated at runtime
|
||||
Config Config
|
||||
Store ConversationStore
|
||||
Chroma *tty.ChromaHighlighter
|
||||
}
|
||||
|
||||
func NewContext() (*Context, error) {
|
||||
configFile := filepath.Join(configDir(), "config.yaml")
|
||||
config, err := NewConfig(configFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{
|
||||
//Logger: logger.Default.LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
store, err := NewSQLStore(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
|
||||
return &Context{*config, store, chroma}, nil
|
||||
}
|
||||
|
||||
func (c *Context) GetModels() (models []string) {
|
||||
modelCounts := make(map[string]int)
|
||||
for _, p := range c.Config.Providers {
|
||||
name := p.Kind
|
||||
if p.Name != "" {
|
||||
name = p.Name
|
||||
}
|
||||
|
||||
for _, m := range p.Models {
|
||||
modelCounts[m]++
|
||||
models = append(models, fmt.Sprintf("%s@%s", m, name))
|
||||
}
|
||||
}
|
||||
|
||||
for m, c := range modelCounts {
|
||||
if c == 1 {
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetAgents() (agents []string) {
|
||||
for _, p := range c.Config.Agents {
|
||||
agents = append(agents, p.Name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetAgent(name string) *Agent {
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, a := range c.Config.Agents {
|
||||
if name != a.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
var enabledTools []api.ToolSpec
|
||||
for _, toolName := range a.Tools {
|
||||
tool, ok := agents.AvailableTools[toolName]
|
||||
if ok {
|
||||
enabledTools = append(enabledTools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
return &Agent{
|
||||
Name: a.Name,
|
||||
SystemPrompt: a.SystemPrompt,
|
||||
Toolbox: enabledTools,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Context) DefaultSystemPrompt() string {
|
||||
if c.Config.Defaults.SystemPromptFile != "" {
|
||||
content, err := util.ReadFileContents(c.Config.Defaults.SystemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v\n", c.Config.Defaults.SystemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return c.Config.Defaults.SystemPrompt
|
||||
}
|
||||
|
||||
func (c *Context) GetModelProvider(model string) (string, api.ChatCompletionProvider, error) {
|
||||
parts := strings.Split(model, "@")
|
||||
|
||||
var provider string
|
||||
if len(parts) > 1 {
|
||||
model = parts[0]
|
||||
provider = parts[1]
|
||||
}
|
||||
|
||||
for _, p := range c.Config.Providers {
|
||||
name := p.Kind
|
||||
if p.Name != "" {
|
||||
name = p.Name
|
||||
}
|
||||
|
||||
if provider != "" && name != provider {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, m := range p.Models {
|
||||
if m == model {
|
||||
switch p.Kind {
|
||||
case "anthropic":
|
||||
url := "https://api.anthropic.com"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &anthropic.AnthropicClient{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
case "google":
|
||||
url := "https://generativelanguage.googleapis.com"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &google.Client{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
case "ollama":
|
||||
url := "http://localhost:11434/api"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &ollama.OllamaClient{
|
||||
BaseURL: url,
|
||||
}, nil
|
||||
case "openai":
|
||||
url := "https://api.openai.com"
|
||||
if p.BaseURL != "" {
|
||||
url = p.BaseURL
|
||||
}
|
||||
return model, &openai.OpenAIClient{
|
||||
BaseURL: url,
|
||||
APIKey: p.APIKey,
|
||||
}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unknown provider kind: %s", p.Kind)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return "", nil, fmt.Errorf("unknown model: %s", model)
|
||||
}
|
||||
|
||||
func configDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func dataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
}
|
|
@ -0,0 +1,433 @@
|
|||
package lmcli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
ConversationByShortName(shortName string) (*api.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
RootMessages(conversationID uint) ([]api.Message, error)
|
||||
LatestConversationMessages() ([]api.Message, error)
|
||||
|
||||
StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error)
|
||||
UpdateConversation(conversation *api.Conversation) error
|
||||
DeleteConversation(conversation *api.Conversation) error
|
||||
CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error)
|
||||
|
||||
MessageByID(messageID uint) (*api.Message, error)
|
||||
MessageReplies(messageID uint) ([]api.Message, error)
|
||||
|
||||
UpdateMessage(message *api.Message) error
|
||||
DeleteMessage(message *api.Message, prune bool) error
|
||||
CloneBranch(toClone api.Message) (*api.Message, uint, error)
|
||||
Reply(to *api.Message, messages ...api.Message) ([]api.Message, error)
|
||||
|
||||
PathToRoot(message *api.Message) ([]api.Message, error)
|
||||
PathToLeaf(message *api.Message) ([]api.Message, error)
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
models := []any{
|
||||
&api.Conversation{},
|
||||
&api.Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) createConversation() (*api.Conversation, error) {
|
||||
// Create the new conversation
|
||||
c := &api.Conversation{}
|
||||
err := s.db.Save(c).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Generate and save its "short name"
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(c.ID)})
|
||||
c.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Updates(c).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateConversation(c *api.Conversation) error {
|
||||
if c == nil || c.ID == 0 {
|
||||
return fmt.Errorf("Conversation is nil or invalid (missing ID)")
|
||||
}
|
||||
return s.db.Updates(c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(c *api.Conversation) error {
|
||||
// Delete messages first
|
||||
err := s.db.Where("conversation_id = ?", c.ID).Delete(&api.Message{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.db.Delete(c).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *api.Message, prune bool) error {
|
||||
panic("Not yet implemented")
|
||||
//return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(m *api.Message) error {
|
||||
if m == nil || m.ID == 0 {
|
||||
return fmt.Errorf("Message is nil or invalid (missing ID)")
|
||||
}
|
||||
return s.db.Updates(m).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var conversations []api.Conversation
|
||||
// ignore error for completions
|
||||
s.db.Find(&conversations)
|
||||
completions := make([]string, 0, len(conversations))
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationByShortName(shortName string) (*api.Conversation, error) {
|
||||
if shortName == "" {
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation api.Conversation
|
||||
err := s.db.Preload("SelectedRoot").Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) RootMessages(conversationID uint) ([]api.Message, error) {
|
||||
var rootMessages []api.Message
|
||||
err := s.db.Where("conversation_id = ? AND parent_id IS NULL", conversationID).Find(&rootMessages).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rootMessages, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageByID(messageID uint) (*api.Message, error) {
|
||||
var message api.Message
|
||||
err := s.db.Preload("Parent").Preload("Replies").Preload("SelectedReply").Where("id = ?", messageID).Find(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) MessageReplies(messageID uint) ([]api.Message, error) {
|
||||
var replies []api.Message
|
||||
err := s.db.Where("parent_id = ?", messageID).Find(&replies).Error
|
||||
return replies, err
|
||||
}
|
||||
|
||||
// StartConversation starts a new conversation with the provided messages
|
||||
func (s *SQLStore) StartConversation(messages ...api.Message) (*api.Conversation, []api.Message, error) {
|
||||
if len(messages) == 0 {
|
||||
return nil, nil, fmt.Errorf("Must provide at least 1 message")
|
||||
}
|
||||
|
||||
// Create new conversation
|
||||
conversation, err := s.createConversation()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Create first message
|
||||
messages[0].Conversation = conversation
|
||||
err = s.db.Create(&messages[0]).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Update conversation's selected root message
|
||||
conversation.SelectedRoot = &messages[0]
|
||||
err = s.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Add additional replies to conversation
|
||||
if len(messages) > 1 {
|
||||
newMessages, err := s.Reply(&messages[0], messages[1:]...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
messages = append([]api.Message{messages[0]}, newMessages...)
|
||||
}
|
||||
return conversation, messages, nil
|
||||
}
|
||||
|
||||
// CloneConversation clones the given conversation and all of its root meesages
|
||||
func (s *SQLStore) CloneConversation(toClone api.Conversation) (*api.Conversation, uint, error) {
|
||||
rootMessages, err := s.RootMessages(toClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
clone, err := s.createConversation()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not create clone: %s", err)
|
||||
}
|
||||
clone.Title = toClone.Title + " - Clone"
|
||||
|
||||
var errors []error
|
||||
var messageCnt uint = 0
|
||||
for _, root := range rootMessages {
|
||||
messageCnt++
|
||||
newRoot := root
|
||||
newRoot.ConversationID = &clone.ID
|
||||
|
||||
cloned, count, err := s.CloneBranch(newRoot)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
messageCnt += count
|
||||
|
||||
if root.ID == *toClone.SelectedRootID {
|
||||
clone.SelectedRootID = &cloned.ID
|
||||
if err := s.UpdateConversation(clone); err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not set selected root on clone: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return nil, 0, fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
return clone, messageCnt, nil
|
||||
}
|
||||
|
||||
// Reply to a message with a series of messages (each following the next)
|
||||
func (s *SQLStore) Reply(to *api.Message, messages ...api.Message) ([]api.Message, error) {
|
||||
var savedMessages []api.Message
|
||||
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
currentParent := to
|
||||
for i := range messages {
|
||||
parent := currentParent
|
||||
message := messages[i]
|
||||
message.Parent = parent
|
||||
message.Conversation = parent.Conversation
|
||||
message.ID = 0
|
||||
message.CreatedAt = time.Time{}
|
||||
|
||||
if err := tx.Create(&message).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update parent selected reply
|
||||
parent.Replies = append(parent.Replies, message)
|
||||
parent.SelectedReply = &message
|
||||
if err := tx.Model(parent).Update("selected_reply_id", message.ID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
savedMessages = append(savedMessages, message)
|
||||
currentParent = &message
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
return savedMessages, err
|
||||
}
|
||||
|
||||
// CloneBranch returns a deep clone of the given message and its replies, returning
|
||||
// a new message object. The new message will be attached to the same parent as
|
||||
// the messageToClone
|
||||
func (s *SQLStore) CloneBranch(messageToClone api.Message) (*api.Message, uint, error) {
|
||||
newMessage := messageToClone
|
||||
newMessage.ID = 0
|
||||
newMessage.Replies = nil
|
||||
newMessage.SelectedReplyID = nil
|
||||
newMessage.SelectedReply = nil
|
||||
|
||||
originalReplies, err := s.MessageReplies(messageToClone.ID)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not fetch message %d replies: %v", messageToClone.ID, err)
|
||||
}
|
||||
|
||||
if err := s.db.Create(&newMessage).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not clone message: %s", err)
|
||||
}
|
||||
|
||||
var replyCount uint = 0
|
||||
for _, reply := range originalReplies {
|
||||
replyCount++
|
||||
|
||||
newReply := reply
|
||||
newReply.ConversationID = messageToClone.ConversationID
|
||||
newReply.ParentID = &newMessage.ID
|
||||
newReply.Parent = &newMessage
|
||||
|
||||
res, c, err := s.CloneBranch(newReply)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
newMessage.Replies = append(newMessage.Replies, *res)
|
||||
replyCount += c
|
||||
|
||||
if reply.ID == *messageToClone.SelectedReplyID {
|
||||
newMessage.SelectedReplyID = &res.ID
|
||||
if err := s.UpdateMessage(&newMessage); err != nil {
|
||||
return nil, 0, fmt.Errorf("Could not update parent select reply ID: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return &newMessage, replyCount, nil
|
||||
}
|
||||
|
||||
func fetchMessages(db *gorm.DB) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
if err := db.Preload("Conversation").Find(&messages).Error; err != nil {
|
||||
return nil, fmt.Errorf("Could not fetch messages: %v", err)
|
||||
}
|
||||
|
||||
messageMap := make(map[uint]api.Message)
|
||||
for i, message := range messages {
|
||||
messageMap[messages[i].ID] = message
|
||||
}
|
||||
|
||||
// Create a map to store replies by their parent ID
|
||||
repliesMap := make(map[uint][]api.Message)
|
||||
for i, message := range messages {
|
||||
if messages[i].ParentID != nil {
|
||||
repliesMap[*messages[i].ParentID] = append(repliesMap[*messages[i].ParentID], message)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign replies, parent, and selected reply to each message
|
||||
for i := range messages {
|
||||
if replies, exists := repliesMap[messages[i].ID]; exists {
|
||||
messages[i].Replies = make([]api.Message, len(replies))
|
||||
for j, m := range replies {
|
||||
messages[i].Replies[j] = m
|
||||
}
|
||||
}
|
||||
if messages[i].ParentID != nil {
|
||||
if parent, exists := messageMap[*messages[i].ParentID]; exists {
|
||||
messages[i].Parent = &parent
|
||||
}
|
||||
}
|
||||
if messages[i].SelectedReplyID != nil {
|
||||
if selectedReply, exists := messageMap[*messages[i].SelectedReplyID]; exists {
|
||||
messages[i].SelectedReply = &selectedReply
|
||||
}
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) buildPath(message *api.Message, getNext func(*api.Message) *uint) ([]api.Message, error) {
|
||||
var messages []api.Message
|
||||
messages, err := fetchMessages(s.db.Where("conversation_id = ?", message.ConversationID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a map to store messages by their ID
|
||||
messageMap := make(map[uint]*api.Message)
|
||||
for i := range messages {
|
||||
messageMap[messages[i].ID] = &messages[i]
|
||||
}
|
||||
|
||||
// Build the path
|
||||
var path []api.Message
|
||||
nextID := &message.ID
|
||||
|
||||
for {
|
||||
current, exists := messageMap[*nextID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("Message with ID %d not found in conversation", *nextID)
|
||||
}
|
||||
|
||||
path = append(path, *current)
|
||||
|
||||
nextID = getNext(current)
|
||||
if nextID == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathToRoot traverses the provided message's Parent until reaching the tree
|
||||
// root and returns a slice of all messages traversed in chronological order
|
||||
// (starting with the root and ending with the message provided)
|
||||
func (s *SQLStore) PathToRoot(message *api.Message) ([]api.Message, error) {
|
||||
if message == nil || message.ID <= 0 {
|
||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||
}
|
||||
|
||||
path, err := s.buildPath(message, func(m *api.Message) *uint {
|
||||
return m.ParentID
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slices.Reverse(path)
|
||||
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// PathToLeaf traverses the provided message's SelectedReply until reaching a
|
||||
// tree leaf and returns a slice of all messages traversed in chronological
|
||||
// order (starting with the message provided and ending with the leaf)
|
||||
func (s *SQLStore) PathToLeaf(message *api.Message) ([]api.Message, error) {
|
||||
if message == nil || message.ID <= 0 {
|
||||
return nil, fmt.Errorf("Message is nil or has invalid ID")
|
||||
}
|
||||
|
||||
return s.buildPath(message, func(m *api.Message) *uint {
|
||||
return m.SelectedReplyID
|
||||
})
|
||||
}
|
||||
|
||||
func (s *SQLStore) LatestConversationMessages() ([]api.Message, error) {
|
||||
var latestMessages []api.Message
|
||||
|
||||
subQuery := s.db.Model(&api.Message{}).
|
||||
Select("MAX(created_at) as max_created_at, conversation_id").
|
||||
Group("conversation_id")
|
||||
|
||||
err := s.db.Model(&api.Message{}).
|
||||
Joins("JOIN (?) as sub on messages.conversation_id = sub.conversation_id AND messages.created_at = sub.max_created_at", subQuery).
|
||||
Group("messages.conversation_id").
|
||||
Order("created_at DESC").
|
||||
Preload("Conversation").
|
||||
Find(&latestMessages).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return latestMessages, nil
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
package bubbles
|
||||
|
||||
import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
type ConfirmPrompt struct {
|
||||
Question string
|
||||
Style lipgloss.Style
|
||||
Payload interface{}
|
||||
value bool
|
||||
answered bool
|
||||
focused bool
|
||||
}
|
||||
|
||||
func NewConfirmPrompt(question string, payload interface{}) ConfirmPrompt {
|
||||
return ConfirmPrompt{
|
||||
Question: question,
|
||||
Style: lipgloss.NewStyle(),
|
||||
Payload: payload,
|
||||
focused: true, // focus by default
|
||||
}
|
||||
}
|
||||
|
||||
type MsgConfirmPromptAnswered struct {
|
||||
Value bool
|
||||
Payload interface{}
|
||||
}
|
||||
|
||||
func (b ConfirmPrompt) Update(msg tea.Msg) (ConfirmPrompt, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
if !b.focused || b.answered {
|
||||
return b, nil
|
||||
}
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
b.value = true
|
||||
b.answered = true
|
||||
b.focused = false
|
||||
return b, func() tea.Msg { return MsgConfirmPromptAnswered{true, b.Payload} }
|
||||
case "n", "N", "esc":
|
||||
b.value = false
|
||||
b.answered = true
|
||||
b.focused = false
|
||||
return b, func() tea.Msg { return MsgConfirmPromptAnswered{false, b.Payload} }
|
||||
}
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b ConfirmPrompt) View() string {
|
||||
return b.Style.Render(b.Question) + lipgloss.NewStyle().Faint(true).Render(" (y/n)")
|
||||
}
|
||||
|
||||
func (b *ConfirmPrompt) Focus() {
|
||||
b.focused = true
|
||||
}
|
||||
|
||||
func (b *ConfirmPrompt) Blur() {
|
||||
b.focused = false
|
||||
}
|
||||
|
||||
func (b ConfirmPrompt) Focused() bool {
|
||||
return b.focused
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type Values struct {
|
||||
ConvShortname string
|
||||
}
|
||||
|
||||
type Shared struct {
|
||||
Ctx *lmcli.Context
|
||||
Values *Values
|
||||
Width int
|
||||
Height int
|
||||
Err error
|
||||
}
|
||||
|
||||
// a convenience struct for holding rendered content for indiviudal UI
|
||||
// elements
|
||||
type Sections struct {
|
||||
Header string
|
||||
Content string
|
||||
Error string
|
||||
Input string
|
||||
Footer string
|
||||
}
|
||||
|
||||
type (
|
||||
// send to change the current state
|
||||
MsgViewChange View
|
||||
// sent to a state when it is entered
|
||||
MsgViewEnter struct{}
|
||||
// sent when an error occurs
|
||||
MsgError error
|
||||
)
|
||||
|
||||
func WrapError(err error) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return MsgError(err)
|
||||
}
|
||||
}
|
||||
|
||||
type View int
|
||||
|
||||
const (
|
||||
StateChat View = iota
|
||||
StateConversations
|
||||
//StateSettings
|
||||
//StateHelp
|
||||
)
|
|
@ -0,0 +1,8 @@
|
|||
package styles
|
||||
|
||||
import "github.com/charmbracelet/lipgloss"
|
||||
|
||||
var Header = lipgloss.NewStyle().
|
||||
PaddingLeft(1).
|
||||
PaddingRight(1).
|
||||
Background(lipgloss.Color("0"))
|
|
@ -0,0 +1,133 @@
|
|||
package tui
|
||||
|
||||
// The terminal UI for lmcli, launched from the `lmcli chat` command
|
||||
// TODO:
|
||||
// - change model
|
||||
// - rename conversation
|
||||
// - set system prompt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/views/chat"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/views/conversations"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// Application model
|
||||
type Model struct {
|
||||
shared.Shared
|
||||
|
||||
state shared.View
|
||||
chat chat.Model
|
||||
conversations conversations.Model
|
||||
}
|
||||
|
||||
func initialModel(ctx *lmcli.Context, values shared.Values) Model {
|
||||
m := Model{
|
||||
Shared: shared.Shared{
|
||||
Ctx: ctx,
|
||||
Values: &values,
|
||||
},
|
||||
}
|
||||
|
||||
m.state = shared.StateChat
|
||||
m.chat = chat.Chat(m.Shared)
|
||||
m.conversations = conversations.Conversations(m.Shared)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m Model) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
m.conversations.Init(),
|
||||
m.chat.Init(),
|
||||
func() tea.Msg {
|
||||
return shared.MsgViewChange(m.state)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Model) handleGlobalInput(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
// delegate input to the active child state first, only handling it at the
|
||||
// global level if the child state does not
|
||||
var cmds []tea.Cmd
|
||||
switch m.state {
|
||||
case shared.StateChat:
|
||||
handled, cmd := m.chat.HandleInput(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
if handled {
|
||||
m.chat, cmd = m.chat.Update(nil)
|
||||
cmds = append(cmds, cmd)
|
||||
return true, tea.Batch(cmds...)
|
||||
}
|
||||
case shared.StateConversations:
|
||||
handled, cmd := m.conversations.HandleInput(msg)
|
||||
cmds = append(cmds, cmd)
|
||||
if handled {
|
||||
m.conversations, cmd = m.conversations.Update(nil)
|
||||
cmds = append(cmds, cmd)
|
||||
return true, tea.Batch(cmds...)
|
||||
}
|
||||
}
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "ctrl+q":
|
||||
return true, tea.Quit
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
handled, cmd := m.handleGlobalInput(msg)
|
||||
if handled {
|
||||
return m, cmd
|
||||
}
|
||||
case shared.MsgViewChange:
|
||||
m.state = shared.View(msg)
|
||||
switch m.state {
|
||||
case shared.StateChat:
|
||||
m.chat.HandleResize(m.Width, m.Height)
|
||||
case shared.StateConversations:
|
||||
m.conversations.HandleResize(m.Width, m.Height)
|
||||
}
|
||||
return m, func() tea.Msg { return shared.MsgViewEnter(struct{}{}) }
|
||||
case tea.WindowSizeMsg:
|
||||
m.Width, m.Height = msg.Width, msg.Height
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
switch m.state {
|
||||
case shared.StateConversations:
|
||||
m.conversations, cmd = m.conversations.Update(msg)
|
||||
case shared.StateChat:
|
||||
m.chat, cmd = m.chat.Update(msg)
|
||||
}
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m Model) View() string {
|
||||
switch m.state {
|
||||
case shared.StateConversations:
|
||||
return m.conversations.View()
|
||||
case shared.StateChat:
|
||||
return m.chat.View()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func Launch(ctx *lmcli.Context, convShortname string) error {
|
||||
p := tea.NewProgram(initialModel(ctx, shared.Values{ConvShortname: convShortname}), tea.WithAltScreen())
|
||||
if _, err := p.Run(); err != nil {
|
||||
return fmt.Errorf("Error running program: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/muesli/reflow/ansi"
|
||||
)
|
||||
|
||||
type MsgTempfileEditorClosed string
|
||||
|
||||
// OpenTempfileEditor opens $EDITOR on a temporary file with the given content.
|
||||
// Upon closing, the contents of the file are read and returned wrapped in a
|
||||
// MsgTempfileEditorClosed
|
||||
func OpenTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
|
||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||
|
||||
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
||||
if err != nil {
|
||||
return func() tea.Msg { return err }
|
||||
}
|
||||
|
||||
editor := os.Getenv("EDITOR")
|
||||
if editor == "" {
|
||||
editor = "vim"
|
||||
}
|
||||
|
||||
c := exec.Command(editor, msgFile.Name())
|
||||
return tea.ExecProcess(c, func(err error) tea.Msg {
|
||||
bytes, err := os.ReadFile(msgFile.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
os.Remove(msgFile.Name())
|
||||
fileContents := string(bytes)
|
||||
if strings.HasPrefix(fileContents, placeholder) {
|
||||
fileContents = fileContents[len(placeholder):]
|
||||
}
|
||||
stripped := strings.Trim(fileContents, "\n \t")
|
||||
return MsgTempfileEditorClosed(stripped)
|
||||
})
|
||||
}
|
||||
|
||||
// similar to lipgloss.Height, except returns 0 instead of 1 on empty strings
|
||||
func Height(str string) int {
|
||||
if str == "" {
|
||||
return 0
|
||||
}
|
||||
return strings.Count(str, "\n") + 1
|
||||
}
|
||||
|
||||
// truncate a string until its rendered cell width + the provided tail fits
|
||||
// within the given width
|
||||
func TruncateToCellWidth(str string, width int, tail string) string {
|
||||
cellWidth := ansi.PrintableRuneWidth(str)
|
||||
if cellWidth <= width {
|
||||
return str
|
||||
}
|
||||
tailWidth := ansi.PrintableRuneWidth(tail)
|
||||
for {
|
||||
str = str[:len(str)-((cellWidth+tailWidth)-width)]
|
||||
cellWidth = ansi.PrintableRuneWidth(str)
|
||||
if cellWidth+tailWidth <= max(width, 0) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return str + tail
|
||||
}
|
||||
|
||||
func ScrollIntoView(vp *viewport.Model, offset int, edge int) {
|
||||
currentOffset := vp.YOffset
|
||||
if offset >= currentOffset && offset < currentOffset+vp.Height {
|
||||
return
|
||||
}
|
||||
distance := currentOffset - offset
|
||||
if distance < 0 {
|
||||
// we should scroll down until it just comes into view
|
||||
vp.SetYOffset(currentOffset - (distance + (vp.Height - edge)) + 1)
|
||||
} else {
|
||||
// we should scroll up
|
||||
vp.SetYOffset(currentOffset - distance - edge)
|
||||
}
|
||||
}
|
||||
|
||||
func ErrorBanner(err error, width int) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
return lipgloss.NewStyle().
|
||||
Width(width).
|
||||
AlignHorizontal(lipgloss.Center).
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("1")).
|
||||
Render(fmt.Sprintf("%s", err))
|
||||
}
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
"github.com/charmbracelet/bubbles/cursor"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/textarea"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// custom tea.Msg types
|
||||
type (
|
||||
// sent when a conversation is (re)loaded
|
||||
msgConversationLoaded struct {
|
||||
conversation *api.Conversation
|
||||
rootMessages []api.Message
|
||||
}
|
||||
// sent when a new conversation title generated
|
||||
msgConversationTitleGenerated string
|
||||
// sent when the conversation has been persisted, triggers a reload of contents
|
||||
msgConversationPersisted struct {
|
||||
isNew bool
|
||||
conversation *api.Conversation
|
||||
messages []api.Message
|
||||
}
|
||||
// sent when a conversation's messages are laoded
|
||||
msgMessagesLoaded []api.Message
|
||||
// a special case of common.MsgError that stops the response waiting animation
|
||||
msgChatResponseError error
|
||||
// sent on each chunk received from LLM
|
||||
msgChatResponseChunk api.Chunk
|
||||
// sent on each completed reply
|
||||
msgChatResponse *api.Message
|
||||
// sent when the response is canceled
|
||||
msgChatResponseCanceled struct{}
|
||||
// sent when results from a tool call are returned
|
||||
msgToolResults []api.ToolResult
|
||||
// sent when the given message is made the new selected reply of its parent
|
||||
msgSelectedReplyCycled *api.Message
|
||||
// sent when the given message is made the new selected root of the current conversation
|
||||
msgSelectedRootCycled *api.Message
|
||||
// sent when a message's contents are updated and saved
|
||||
msgMessageUpdated *api.Message
|
||||
// sent when a message is cloned, with the cloned message
|
||||
msgMessageCloned *api.Message
|
||||
)
|
||||
|
||||
type focusState int
|
||||
|
||||
const (
|
||||
focusInput focusState = iota
|
||||
focusMessages
|
||||
)
|
||||
|
||||
type editorTarget int
|
||||
|
||||
const (
|
||||
input editorTarget = iota
|
||||
selectedMessage
|
||||
)
|
||||
|
||||
type state int
|
||||
|
||||
const (
|
||||
idle state = iota
|
||||
loading
|
||||
pendingResponse
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
shared.Shared
|
||||
shared.Sections
|
||||
|
||||
// app state
|
||||
state state // current overall status of the view
|
||||
conversation *api.Conversation
|
||||
rootMessages []api.Message
|
||||
messages []api.Message
|
||||
selectedMessage int
|
||||
editorTarget editorTarget
|
||||
stopSignal chan struct{}
|
||||
replyChan chan api.Message
|
||||
chatReplyChunks chan api.Chunk
|
||||
persistence bool // whether we will save new messages in the conversation
|
||||
|
||||
// ui state
|
||||
focus focusState
|
||||
wrap bool // whether message content is wrapped to viewport width
|
||||
showToolResults bool // whether tool calls and results are shown
|
||||
messageCache []string // cache of syntax highlighted and wrapped message content
|
||||
messageOffsets []int
|
||||
|
||||
// ui elements
|
||||
content viewport.Model
|
||||
input textarea.Model
|
||||
spinner spinner.Model
|
||||
replyCursor cursor.Model // cursor to indicate incoming response
|
||||
|
||||
// metrics
|
||||
tokenCount uint
|
||||
startTime time.Time
|
||||
elapsed time.Duration
|
||||
}
|
||||
|
||||
func Chat(shared shared.Shared) Model {
|
||||
m := Model{
|
||||
Shared: shared,
|
||||
|
||||
state: idle,
|
||||
conversation: &api.Conversation{},
|
||||
persistence: true,
|
||||
|
||||
stopSignal: make(chan struct{}),
|
||||
replyChan: make(chan api.Message),
|
||||
chatReplyChunks: make(chan api.Chunk),
|
||||
|
||||
wrap: true,
|
||||
selectedMessage: -1,
|
||||
|
||||
content: viewport.New(0, 0),
|
||||
input: textarea.New(),
|
||||
spinner: spinner.New(spinner.WithSpinner(
|
||||
spinner.Spinner{
|
||||
Frames: []string{
|
||||
". ",
|
||||
".. ",
|
||||
"...",
|
||||
".. ",
|
||||
". ",
|
||||
" ",
|
||||
},
|
||||
FPS: time.Second / 3,
|
||||
},
|
||||
)),
|
||||
replyCursor: cursor.New(),
|
||||
}
|
||||
|
||||
m.replyCursor.SetChar(" ")
|
||||
m.replyCursor.Focus()
|
||||
|
||||
system := shared.Ctx.DefaultSystemPrompt()
|
||||
|
||||
agent := shared.Ctx.GetAgent(shared.Ctx.Config.Defaults.Agent)
|
||||
if agent != nil && agent.SystemPrompt != "" {
|
||||
system = agent.SystemPrompt
|
||||
}
|
||||
|
||||
if system != "" {
|
||||
m.messages = api.ApplySystemPrompt(m.messages, system, false)
|
||||
}
|
||||
|
||||
m.input.Focus()
|
||||
m.input.MaxHeight = 0
|
||||
m.input.CharLimit = 0
|
||||
m.input.ShowLineNumbers = false
|
||||
m.input.Placeholder = "Enter a message"
|
||||
|
||||
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
||||
m.input.FocusedStyle.Base = inputFocusedStyle
|
||||
m.input.BlurredStyle.Base = inputBlurredStyle
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m Model) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
m.waitForResponseChunk(),
|
||||
)
|
||||
}
|
|
@ -0,0 +1,308 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/agents"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func (m *Model) setMessage(i int, msg api.Message) {
|
||||
if i >= len(m.messages) {
|
||||
panic("i out of range")
|
||||
}
|
||||
m.messages[i] = msg
|
||||
m.messageCache[i] = m.renderMessage(i)
|
||||
}
|
||||
|
||||
func (m *Model) addMessage(msg api.Message) {
|
||||
m.messages = append(m.messages, msg)
|
||||
m.messageCache = append(m.messageCache, m.renderMessage(len(m.messages)-1))
|
||||
}
|
||||
|
||||
func (m *Model) setMessageContents(i int, content string) {
|
||||
if i >= len(m.messages) {
|
||||
panic("i out of range")
|
||||
}
|
||||
m.messages[i].Content = content
|
||||
m.messageCache[i] = m.renderMessage(i)
|
||||
}
|
||||
|
||||
func (m *Model) rebuildMessageCache() {
|
||||
m.messageCache = make([]string, len(m.messages))
|
||||
for i := range m.messages {
|
||||
m.messageCache[i] = m.renderMessage(i)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) updateContent() {
|
||||
atBottom := m.content.AtBottom()
|
||||
m.content.SetContent(m.conversationMessagesView())
|
||||
if atBottom {
|
||||
// if we were at bottom before the update, scroll with the output
|
||||
m.content.GotoBottom()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) loadConversation(shortname string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if shortname == "" {
|
||||
return nil
|
||||
}
|
||||
c, err := m.Shared.Ctx.Store.ConversationByShortName(shortname)
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not lookup conversation: %v", err))
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return shared.MsgError(fmt.Errorf("Conversation not found: %s", shortname))
|
||||
}
|
||||
rootMessages, err := m.Shared.Ctx.Store.RootMessages(c.ID)
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not load conversation root messages: %v\n", err))
|
||||
}
|
||||
return msgConversationLoaded{c, rootMessages}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) loadConversationMessages() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, err := m.Shared.Ctx.Store.PathToLeaf(m.conversation.SelectedRoot)
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
||||
}
|
||||
return msgMessagesLoaded(messages)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) generateConversationTitle() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
title, err := cmdutil.GenerateTitle(m.Shared.Ctx, m.messages)
|
||||
if err != nil {
|
||||
return shared.MsgError(err)
|
||||
}
|
||||
return msgConversationTitleGenerated(title)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) updateConversationTitle(conversation *api.Conversation) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.Shared.Ctx.Store.UpdateConversation(conversation)
|
||||
if err != nil {
|
||||
return shared.WrapError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Clones the given message (and its descendents). If selected is true, updates
|
||||
// either its parent's SelectedReply or its conversation's SelectedRoot to
|
||||
// point to the new clone
|
||||
func (m *Model) cloneMessage(message api.Message, selected bool) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
msg, _, err := m.Ctx.Store.CloneBranch(message)
|
||||
if err != nil {
|
||||
return shared.WrapError(fmt.Errorf("Could not clone message: %v", err))
|
||||
}
|
||||
if selected {
|
||||
if msg.Parent == nil {
|
||||
msg.Conversation.SelectedRoot = msg
|
||||
err = m.Shared.Ctx.Store.UpdateConversation(msg.Conversation)
|
||||
} else {
|
||||
msg.Parent.SelectedReply = msg
|
||||
err = m.Shared.Ctx.Store.UpdateMessage(msg.Parent)
|
||||
}
|
||||
if err != nil {
|
||||
return shared.WrapError(fmt.Errorf("Could not update selected message: %v", err))
|
||||
}
|
||||
}
|
||||
return msgMessageCloned(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) updateMessageContent(message *api.Message) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.Shared.Ctx.Store.UpdateMessage(message)
|
||||
if err != nil {
|
||||
return shared.WrapError(fmt.Errorf("Could not update message: %v", err))
|
||||
}
|
||||
return msgMessageUpdated(message)
|
||||
}
|
||||
}
|
||||
|
||||
func cycleSelectedMessage(selected *api.Message, choices []api.Message, dir MessageCycleDirection) (*api.Message, error) {
|
||||
currentIndex := -1
|
||||
for i, reply := range choices {
|
||||
if reply.ID == selected.ID {
|
||||
currentIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if currentIndex < 0 {
|
||||
// this should probably be an assert
|
||||
return nil, fmt.Errorf("Selected message %d not found in choices, this is a bug", selected.ID)
|
||||
}
|
||||
|
||||
var next int
|
||||
if dir == CyclePrev {
|
||||
// Wrap around to the last reply if at the beginning
|
||||
next = (currentIndex - 1 + len(choices)) % len(choices)
|
||||
} else {
|
||||
// Wrap around to the first reply if at the end
|
||||
next = (currentIndex + 1) % len(choices)
|
||||
}
|
||||
return &choices[next], nil
|
||||
}
|
||||
|
||||
func (m *Model) cycleSelectedRoot(conv *api.Conversation, dir MessageCycleDirection) tea.Cmd {
|
||||
if len(m.rootMessages) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func() tea.Msg {
|
||||
nextRoot, err := cycleSelectedMessage(conv.SelectedRoot, m.rootMessages, dir)
|
||||
if err != nil {
|
||||
return shared.WrapError(err)
|
||||
}
|
||||
|
||||
conv.SelectedRoot = nextRoot
|
||||
err = m.Shared.Ctx.Store.UpdateConversation(conv)
|
||||
if err != nil {
|
||||
return shared.WrapError(fmt.Errorf("Could not update conversation SelectedRoot: %v", err))
|
||||
}
|
||||
return msgSelectedRootCycled(nextRoot)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) cycleSelectedReply(message *api.Message, dir MessageCycleDirection) tea.Cmd {
|
||||
if len(message.Replies) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func() tea.Msg {
|
||||
nextReply, err := cycleSelectedMessage(message.SelectedReply, message.Replies, dir)
|
||||
if err != nil {
|
||||
return shared.WrapError(err)
|
||||
}
|
||||
|
||||
message.SelectedReply = nextReply
|
||||
err = m.Shared.Ctx.Store.UpdateMessage(message)
|
||||
if err != nil {
|
||||
return shared.WrapError(fmt.Errorf("Could not update message SelectedReply: %v", err))
|
||||
}
|
||||
return msgSelectedReplyCycled(nextReply)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) persistConversation() tea.Cmd {
|
||||
conversation := m.conversation
|
||||
messages := m.messages
|
||||
|
||||
var err error
|
||||
if conversation.ID == 0 {
|
||||
return func() tea.Msg {
|
||||
// Start a new conversation with all messages so far
|
||||
conversation, messages, err = m.Shared.Ctx.Store.StartConversation(messages...)
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not start new conversation: %v", err))
|
||||
}
|
||||
return msgConversationPersisted{true, conversation, messages}
|
||||
}
|
||||
}
|
||||
|
||||
return func() tea.Msg {
|
||||
// else, we'll handle updating an existing conversation's messages
|
||||
for i := range messages {
|
||||
if messages[i].ID > 0 {
|
||||
// message has an ID, update it
|
||||
err := m.Shared.Ctx.Store.UpdateMessage(&messages[i])
|
||||
if err != nil {
|
||||
return shared.MsgError(err)
|
||||
}
|
||||
} else if i > 0 {
|
||||
// messages is new, so add it as a reply to previous message
|
||||
saved, err := m.Shared.Ctx.Store.Reply(&messages[i-1], messages[i])
|
||||
if err != nil {
|
||||
return shared.MsgError(err)
|
||||
}
|
||||
messages[i] = saved[0]
|
||||
} else {
|
||||
// message has no id and no previous messages to add it to
|
||||
// this shouldn't happen?
|
||||
return fmt.Errorf("Error: no messages to reply to")
|
||||
}
|
||||
}
|
||||
return msgConversationPersisted{false, conversation, messages}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) executeToolCalls(toolCalls []api.ToolCall) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
|
||||
if agent == nil {
|
||||
return shared.MsgError(fmt.Errorf("Attempted to execute tool calls with no agent configured"))
|
||||
}
|
||||
|
||||
results, err := agents.ExecuteToolCalls(toolCalls, agent.Toolbox)
|
||||
if err != nil {
|
||||
return shared.MsgError(err)
|
||||
}
|
||||
return msgToolResults(results)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) promptLLM() tea.Cmd {
|
||||
m.state = pendingResponse
|
||||
m.replyCursor.Blink = false
|
||||
|
||||
m.startTime = time.Now()
|
||||
m.elapsed = 0
|
||||
m.tokenCount = 0
|
||||
|
||||
return func() tea.Msg {
|
||||
model, provider, err := m.Shared.Ctx.GetModelProvider(*m.Shared.Ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return shared.MsgError(err)
|
||||
}
|
||||
|
||||
params := api.RequestParameters{
|
||||
Model: model,
|
||||
MaxTokens: *m.Shared.Ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *m.Shared.Ctx.Config.Defaults.Temperature,
|
||||
}
|
||||
|
||||
agent := m.Shared.Ctx.GetAgent(m.Shared.Ctx.Config.Defaults.Agent)
|
||||
if agent != nil {
|
||||
params.Toolbox = agent.Toolbox
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-m.stopSignal:
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
resp, err := provider.CreateChatCompletionStream(
|
||||
ctx, params, m.messages, m.chatReplyChunks,
|
||||
)
|
||||
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return msgChatResponseCanceled(struct{}{})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return msgChatResponseError(err)
|
||||
}
|
||||
|
||||
return msgChatResponse(resp)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,180 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type MessageCycleDirection int
|
||||
|
||||
const (
|
||||
CycleNext MessageCycleDirection = 1
|
||||
CyclePrev MessageCycleDirection = -1
|
||||
)
|
||||
|
||||
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
switch m.focus {
|
||||
case focusInput:
|
||||
consumed, cmd := m.handleInputKey(msg)
|
||||
if consumed {
|
||||
return true, cmd
|
||||
}
|
||||
case focusMessages:
|
||||
consumed, cmd := m.handleMessagesKey(msg)
|
||||
if consumed {
|
||||
return true, cmd
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
if m.state == pendingResponse {
|
||||
m.stopSignal <- struct{}{}
|
||||
return true, nil
|
||||
}
|
||||
return true, func() tea.Msg {
|
||||
return shared.MsgViewChange(shared.StateConversations)
|
||||
}
|
||||
case "ctrl+c":
|
||||
if m.state == pendingResponse {
|
||||
m.stopSignal <- struct{}{}
|
||||
return true, nil
|
||||
}
|
||||
case "ctrl+p":
|
||||
m.persistence = !m.persistence
|
||||
return true, nil
|
||||
case "ctrl+t":
|
||||
m.showToolResults = !m.showToolResults
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
return true, nil
|
||||
case "ctrl+w":
|
||||
m.wrap = !m.wrap
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// handleMessagesKey handles input when the messages pane is focused
|
||||
func (m *Model) handleMessagesKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "tab", "enter":
|
||||
m.focus = focusInput
|
||||
m.updateContent()
|
||||
m.input.Focus()
|
||||
return true, nil
|
||||
case "e":
|
||||
if m.selectedMessage < len(m.messages) {
|
||||
m.editorTarget = selectedMessage
|
||||
return true, tuiutil.OpenTempfileEditor(
|
||||
"message.*.md",
|
||||
m.messages[m.selectedMessage].Content,
|
||||
"# Edit the message below\n",
|
||||
)
|
||||
}
|
||||
return false, nil
|
||||
case "ctrl+k":
|
||||
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
|
||||
m.selectedMessage--
|
||||
m.updateContent()
|
||||
offset := m.messageOffsets[m.selectedMessage]
|
||||
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
|
||||
}
|
||||
return true, nil
|
||||
case "ctrl+j":
|
||||
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
|
||||
m.selectedMessage++
|
||||
m.updateContent()
|
||||
offset := m.messageOffsets[m.selectedMessage]
|
||||
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
|
||||
}
|
||||
return true, nil
|
||||
case "ctrl+h", "ctrl+l":
|
||||
dir := CyclePrev
|
||||
if msg.String() == "ctrl+l" {
|
||||
dir = CycleNext
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
if m.selectedMessage == 0 {
|
||||
cmd = m.cycleSelectedRoot(m.conversation, dir)
|
||||
} else if m.selectedMessage > 0 {
|
||||
cmd = m.cycleSelectedReply(&m.messages[m.selectedMessage-1], dir)
|
||||
}
|
||||
|
||||
return cmd != nil, cmd
|
||||
case "ctrl+r":
|
||||
// resubmit the conversation with all messages up until and including the selected message
|
||||
if m.state == idle && m.selectedMessage < len(m.messages) {
|
||||
m.messages = m.messages[:m.selectedMessage+1]
|
||||
m.messageCache = m.messageCache[:m.selectedMessage+1]
|
||||
cmd := m.promptLLM()
|
||||
m.updateContent()
|
||||
m.content.GotoBottom()
|
||||
return true, cmd
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// handleInputKey handles input when the input textarea is focused
|
||||
func (m *Model) handleInputKey(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.focus = focusMessages
|
||||
if len(m.messages) > 0 {
|
||||
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
|
||||
m.selectedMessage = len(m.messages) - 1
|
||||
}
|
||||
offset := m.messageOffsets[m.selectedMessage]
|
||||
tuiutil.ScrollIntoView(&m.content, offset, m.content.Height/2)
|
||||
}
|
||||
m.updateContent()
|
||||
m.input.Blur()
|
||||
return true, nil
|
||||
case "ctrl+s":
|
||||
// TODO: call a "handleSend" function which returns a tea.Cmd
|
||||
if m.state != idle {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
input := strings.TrimSpace(m.input.Value())
|
||||
if input == "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == api.MessageRoleUser {
|
||||
return true, shared.WrapError(fmt.Errorf("Can't reply to a user message"))
|
||||
}
|
||||
|
||||
m.addMessage(api.Message{
|
||||
Role: api.MessageRoleUser,
|
||||
Content: input,
|
||||
})
|
||||
|
||||
m.input.SetValue("")
|
||||
|
||||
var cmds []tea.Cmd
|
||||
if m.persistence {
|
||||
cmds = append(cmds, m.persistConversation())
|
||||
}
|
||||
|
||||
cmds = append(cmds, m.promptLLM())
|
||||
|
||||
m.updateContent()
|
||||
m.content.GotoBottom()
|
||||
return true, tea.Batch(cmds...)
|
||||
case "ctrl+e":
|
||||
cmd := tuiutil.OpenTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
|
||||
m.editorTarget = input
|
||||
return true, cmd
|
||||
}
|
||||
return false, nil
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||
"github.com/charmbracelet/bubbles/cursor"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
func (m *Model) HandleResize(width, height int) {
|
||||
m.Width, m.Height = width, height
|
||||
m.content.Width = width
|
||||
m.input.SetWidth(width - m.input.FocusedStyle.Base.GetHorizontalFrameSize())
|
||||
if len(m.messages) > 0 {
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) waitForResponseChunk() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return msgChatResponseChunk(<-m.chatReplyChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.HandleResize(msg.Width, msg.Height)
|
||||
case shared.MsgViewEnter:
|
||||
// wake up spinners and cursors
|
||||
cmds = append(cmds, cursor.Blink, m.spinner.Tick)
|
||||
|
||||
if m.Shared.Values.ConvShortname != "" {
|
||||
// (re)load conversation contents
|
||||
cmds = append(cmds, m.loadConversation(m.Shared.Values.ConvShortname))
|
||||
|
||||
if m.conversation.ShortName.String != m.Shared.Values.ConvShortname {
|
||||
// clear existing messages if we're loading a new conversation
|
||||
m.messages = []api.Message{}
|
||||
m.selectedMessage = 0
|
||||
}
|
||||
}
|
||||
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
case tuiutil.MsgTempfileEditorClosed:
|
||||
contents := string(msg)
|
||||
switch m.editorTarget {
|
||||
case input:
|
||||
m.input.SetValue(contents)
|
||||
case selectedMessage:
|
||||
toEdit := m.messages[m.selectedMessage]
|
||||
if toEdit.Content != contents {
|
||||
toEdit.Content = contents
|
||||
m.setMessage(m.selectedMessage, toEdit)
|
||||
if m.persistence && toEdit.ID > 0 {
|
||||
// create clone of message with its new contents
|
||||
cmds = append(cmds, m.cloneMessage(toEdit, true))
|
||||
}
|
||||
}
|
||||
}
|
||||
case msgConversationLoaded:
|
||||
m.conversation = msg.conversation
|
||||
m.rootMessages = msg.rootMessages
|
||||
m.selectedMessage = -1
|
||||
if len(m.rootMessages) > 0 {
|
||||
cmds = append(cmds, m.loadConversationMessages())
|
||||
}
|
||||
case msgMessagesLoaded:
|
||||
m.messages = msg
|
||||
if m.selectedMessage == -1 {
|
||||
m.selectedMessage = len(msg) - 1
|
||||
} else {
|
||||
m.selectedMessage = min(m.selectedMessage, len(m.messages))
|
||||
}
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
case msgChatResponseChunk:
|
||||
cmds = append(cmds, m.waitForResponseChunk()) // wait for the next chunk
|
||||
|
||||
if msg.Content == "" {
|
||||
break
|
||||
}
|
||||
|
||||
last := len(m.messages) - 1
|
||||
if last >= 0 && m.messages[last].Role.IsAssistant() {
|
||||
// append chunk to existing message
|
||||
m.setMessageContents(last, m.messages[last].Content+msg.Content)
|
||||
} else {
|
||||
// use chunk in a new message
|
||||
m.addMessage(api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
m.updateContent()
|
||||
|
||||
// show cursor and reset blink interval (simulate typing)
|
||||
m.replyCursor.Blink = false
|
||||
cmds = append(cmds, m.replyCursor.BlinkCmd())
|
||||
|
||||
m.tokenCount += msg.TokenCount
|
||||
m.elapsed = time.Now().Sub(m.startTime)
|
||||
case msgChatResponse:
|
||||
m.state = idle
|
||||
|
||||
reply := (*api.Message)(msg)
|
||||
reply.Content = strings.TrimSpace(reply.Content)
|
||||
|
||||
last := len(m.messages) - 1
|
||||
if last < 0 {
|
||||
panic("Unexpected empty messages handling msgAssistantReply")
|
||||
}
|
||||
|
||||
if m.messages[last].Role.IsAssistant() {
|
||||
// TODO: handle continuations gracefully - some models support them well, others fail horribly.
|
||||
m.setMessage(last, *reply)
|
||||
} else {
|
||||
m.addMessage(*reply)
|
||||
}
|
||||
|
||||
switch reply.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
// TODO: user confirmation before execution
|
||||
// m.state = waitingForConfirmation
|
||||
cmds = append(cmds, m.executeToolCalls(reply.ToolCalls))
|
||||
}
|
||||
|
||||
if m.persistence {
|
||||
cmds = append(cmds, m.persistConversation())
|
||||
}
|
||||
|
||||
if m.conversation.Title == "" {
|
||||
cmds = append(cmds, m.generateConversationTitle())
|
||||
}
|
||||
|
||||
m.updateContent()
|
||||
case msgChatResponseCanceled:
|
||||
m.state = idle
|
||||
m.updateContent()
|
||||
case msgChatResponseError:
|
||||
m.state = idle
|
||||
m.Shared.Err = error(msg)
|
||||
m.updateContent()
|
||||
case msgToolResults:
|
||||
last := len(m.messages) - 1
|
||||
if last < 0 {
|
||||
panic("Unexpected empty messages handling msgAssistantReply")
|
||||
}
|
||||
|
||||
if m.messages[last].Role != api.MessageRoleToolCall {
|
||||
panic("Previous message not a tool call, unexpected")
|
||||
}
|
||||
|
||||
m.addMessage(api.Message{
|
||||
Role: api.MessageRoleToolResult,
|
||||
ToolResults: api.ToolResults(msg),
|
||||
})
|
||||
|
||||
if m.persistence {
|
||||
cmds = append(cmds, m.persistConversation())
|
||||
}
|
||||
|
||||
m.updateContent()
|
||||
case msgConversationTitleGenerated:
|
||||
title := string(msg)
|
||||
m.conversation.Title = title
|
||||
if m.persistence {
|
||||
cmds = append(cmds, m.updateConversationTitle(m.conversation))
|
||||
}
|
||||
case cursor.BlinkMsg:
|
||||
if m.state == pendingResponse {
|
||||
// ensure we show the updated "wait for response" cursor blink state
|
||||
last := len(m.messages)-1
|
||||
m.messageCache[last] = m.renderMessage(last)
|
||||
m.updateContent()
|
||||
}
|
||||
case msgConversationPersisted:
|
||||
m.conversation = msg.conversation
|
||||
m.messages = msg.messages
|
||||
if msg.isNew {
|
||||
m.rootMessages = []api.Message{m.messages[0]}
|
||||
}
|
||||
m.rebuildMessageCache()
|
||||
m.updateContent()
|
||||
case msgMessageCloned:
|
||||
if msg.Parent == nil {
|
||||
m.conversation = msg.Conversation
|
||||
m.rootMessages = append(m.rootMessages, *msg)
|
||||
}
|
||||
cmds = append(cmds, m.loadConversationMessages())
|
||||
case msgSelectedRootCycled, msgSelectedReplyCycled, msgMessageUpdated:
|
||||
cmds = append(cmds, m.loadConversationMessages())
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.spinner, cmd = m.spinner.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
m.replyCursor, cmd = m.replyCursor.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
prevInputLineCnt := m.input.LineCount()
|
||||
inputCaptured := false
|
||||
m.input, cmd = m.input.Update(msg)
|
||||
if cmd != nil {
|
||||
inputCaptured = true
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
if !inputCaptured {
|
||||
m.content, cmd = m.content.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// update views once window dimensions are known
|
||||
if m.Width > 0 {
|
||||
m.Header = m.headerView()
|
||||
m.Footer = m.footerView()
|
||||
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
|
||||
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
|
||||
|
||||
// calculate clamped input height to accomodate input text
|
||||
// minimum 4 lines, maximum half of content area
|
||||
newHeight := max(4, min((m.Height-fixedHeight-1)/2, m.input.LineCount()))
|
||||
m.input.SetHeight(newHeight)
|
||||
m.Input = m.input.View()
|
||||
|
||||
// remaining height towards content
|
||||
m.content.Height = m.Height - fixedHeight - tuiutil.Height(m.Input)
|
||||
m.Content = m.content.View()
|
||||
}
|
||||
|
||||
// this is a pretty nasty hack to ensure the input area viewport doesn't
|
||||
// scroll below its content, which can happen when the input viewport
|
||||
// height has grown, or previously entered lines have been deleted
|
||||
if prevInputLineCnt != m.input.LineCount() {
|
||||
// dist is the distance we'd need to scroll up from the current cursor
|
||||
// position to position the last input line at the bottom of the
|
||||
// viewport. if negative, we're already scrolled above the bottom
|
||||
dist := m.input.Line() - (m.input.LineCount() - m.input.Height())
|
||||
if dist > 0 {
|
||||
for i := 0; i < dist; i++ {
|
||||
// move cursor up until content reaches the bottom of the viewport
|
||||
m.input.CursorUp()
|
||||
}
|
||||
m.input, _ = m.input.Update(nil)
|
||||
for i := 0; i < dist; i++ {
|
||||
// move cursor back down to its previous position
|
||||
m.input.CursorDown()
|
||||
}
|
||||
m.input, _ = m.input.Update(nil)
|
||||
}
|
||||
}
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
|
@ -0,0 +1,321 @@
|
|||
package chat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/muesli/reflow/wordwrap"
|
||||
"github.com/muesli/reflow/wrap"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// styles
|
||||
var (
|
||||
messageHeadingStyle = lipgloss.NewStyle().
|
||||
MarginTop(1).
|
||||
MarginBottom(1).
|
||||
PaddingLeft(1).
|
||||
Bold(true)
|
||||
|
||||
userStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("10"))
|
||||
|
||||
assistantStyle = lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("12"))
|
||||
|
||||
messageStyle = lipgloss.NewStyle().
|
||||
PaddingLeft(2).
|
||||
PaddingRight(2)
|
||||
|
||||
inputFocusedStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder(), true, true, true, false)
|
||||
|
||||
inputBlurredStyle = lipgloss.NewStyle().
|
||||
Faint(true).
|
||||
Border(lipgloss.RoundedBorder(), true, true, true, false)
|
||||
|
||||
footerStyle = lipgloss.NewStyle()
|
||||
)
|
||||
|
||||
func (m Model) View() string {
|
||||
if m.Width == 0 {
|
||||
return ""
|
||||
}
|
||||
sections := make([]string, 0, 6)
|
||||
|
||||
if m.Header != "" {
|
||||
sections = append(sections, m.Header)
|
||||
}
|
||||
|
||||
sections = append(sections, m.Content)
|
||||
if m.Error != "" {
|
||||
sections = append(sections, m.Error)
|
||||
}
|
||||
sections = append(sections, m.Input)
|
||||
|
||||
if m.Footer != "" {
|
||||
sections = append(sections, m.Footer)
|
||||
}
|
||||
|
||||
return lipgloss.JoinVertical(lipgloss.Left, sections...)
|
||||
}
|
||||
|
||||
func (m *Model) renderMessageHeading(i int, message *api.Message) string {
|
||||
icon := ""
|
||||
friendly := message.Role.FriendlyRole()
|
||||
style := lipgloss.NewStyle().Faint(true).Bold(true)
|
||||
|
||||
switch message.Role {
|
||||
case api.MessageRoleSystem:
|
||||
icon = "⚙️"
|
||||
case api.MessageRoleUser:
|
||||
style = userStyle
|
||||
case api.MessageRoleAssistant:
|
||||
style = assistantStyle
|
||||
case api.MessageRoleToolCall:
|
||||
style = assistantStyle
|
||||
friendly = api.MessageRoleAssistant.FriendlyRole()
|
||||
case api.MessageRoleToolResult:
|
||||
icon = "🔧"
|
||||
}
|
||||
|
||||
user := style.Render(icon + friendly)
|
||||
|
||||
var prefix string
|
||||
var suffix string
|
||||
|
||||
faint := lipgloss.NewStyle().Faint(true)
|
||||
|
||||
if i == 0 && len(m.rootMessages) > 1 && m.conversation.SelectedRootID != nil {
|
||||
selectedRootIndex := 0
|
||||
for j, reply := range m.rootMessages {
|
||||
if reply.ID == *m.conversation.SelectedRootID {
|
||||
selectedRootIndex = j
|
||||
break
|
||||
}
|
||||
}
|
||||
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedRootIndex+1, len(m.rootMessages)))
|
||||
}
|
||||
if i > 0 && len(m.messages[i-1].Replies) > 1 {
|
||||
// Find the selected reply index
|
||||
selectedReplyIndex := 0
|
||||
for j, reply := range m.messages[i-1].Replies {
|
||||
if reply.ID == *m.messages[i-1].SelectedReplyID {
|
||||
selectedReplyIndex = j
|
||||
break
|
||||
}
|
||||
}
|
||||
suffix += faint.Render(fmt.Sprintf(" <%d/%d>", selectedReplyIndex+1, len(m.messages[i-1].Replies)))
|
||||
}
|
||||
|
||||
if m.focus == focusMessages {
|
||||
if i == m.selectedMessage {
|
||||
prefix = "> "
|
||||
}
|
||||
}
|
||||
|
||||
if message.ID == 0 {
|
||||
suffix += faint.Render(" (not saved)")
|
||||
}
|
||||
|
||||
return messageHeadingStyle.Render(prefix + user + suffix)
|
||||
}
|
||||
|
||||
// renderMessages renders the message at the given index as it should be shown
|
||||
// *at this moment* - we render differently depending on the current application
|
||||
// state (window size, etc, etc).
|
||||
func (m *Model) renderMessage(i int) string {
|
||||
msg := &m.messages[i]
|
||||
|
||||
// Write message contents
|
||||
sb := &strings.Builder{}
|
||||
sb.Grow(len(msg.Content) * 2)
|
||||
if msg.Content != "" {
|
||||
err := m.Shared.Ctx.Chroma.Highlight(sb, msg.Content)
|
||||
if err != nil {
|
||||
sb.Reset()
|
||||
sb.WriteString(msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
isLast := i == len(m.messages)-1
|
||||
isAssistant := msg.Role == api.MessageRoleAssistant
|
||||
|
||||
if m.state == pendingResponse && isLast && isAssistant {
|
||||
// Show the assistant's cursor
|
||||
sb.WriteString(m.replyCursor.View())
|
||||
}
|
||||
|
||||
// Write tool call info
|
||||
var toolString string
|
||||
switch msg.Role {
|
||||
case api.MessageRoleToolCall:
|
||||
bytes, err := yaml.Marshal(msg.ToolCalls)
|
||||
if err != nil {
|
||||
toolString = "Could not serialize ToolCalls"
|
||||
} else {
|
||||
toolString = "tool_calls:\n" + string(bytes)
|
||||
}
|
||||
case api.MessageRoleToolResult:
|
||||
type renderedResult struct {
|
||||
ToolName string `yaml:"tool"`
|
||||
Result any `yaml:"result,omitempty"`
|
||||
}
|
||||
|
||||
var toolResults []renderedResult
|
||||
for _, result := range msg.ToolResults {
|
||||
if m.showToolResults {
|
||||
var jsonResult interface{}
|
||||
err := json.Unmarshal([]byte(result.Result), &jsonResult)
|
||||
if err != nil {
|
||||
// If parsing as JSON fails, treat Result as a plain string
|
||||
toolResults = append(toolResults, renderedResult{
|
||||
ToolName: result.ToolName,
|
||||
Result: result.Result,
|
||||
})
|
||||
} else {
|
||||
// If parsing as JSON succeeds, marshal the parsed JSON into YAML
|
||||
toolResults = append(toolResults, renderedResult{
|
||||
ToolName: result.ToolName,
|
||||
Result: &jsonResult,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Only show the tool name when results are hidden
|
||||
toolResults = append(toolResults, renderedResult{
|
||||
ToolName: result.ToolName,
|
||||
Result: "(hidden, press ctrl+t to view)",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
bytes, err := yaml.Marshal(toolResults)
|
||||
if err != nil {
|
||||
toolString = "Could not serialize ToolResults"
|
||||
} else {
|
||||
toolString = "tool_results:\n" + string(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
if toolString != "" {
|
||||
toolString = strings.TrimRight(toolString, "\n")
|
||||
if msg.Content != "" {
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
_ = m.Shared.Ctx.Chroma.HighlightLang(sb, toolString, "yaml")
|
||||
}
|
||||
|
||||
content := strings.TrimRight(sb.String(), "\n")
|
||||
|
||||
if m.wrap {
|
||||
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding()
|
||||
// first we word-wrap text to slightly less than desired width (since
|
||||
// wordwrap seems to have an off-by-1 issue), then hard wrap at
|
||||
// desired with
|
||||
content = wrap.String(wordwrap.String(content, wrapWidth-2), wrapWidth)
|
||||
}
|
||||
|
||||
return messageStyle.Width(0).Render(content)
|
||||
}
|
||||
|
||||
// render the conversation into a string
|
||||
func (m *Model) conversationMessagesView() string {
|
||||
sb := strings.Builder{}
|
||||
|
||||
m.messageOffsets = make([]int, len(m.messages))
|
||||
lineCnt := 1
|
||||
for i, message := range m.messages {
|
||||
m.messageOffsets[i] = lineCnt
|
||||
|
||||
heading := m.renderMessageHeading(i, &message)
|
||||
sb.WriteString(heading)
|
||||
sb.WriteString("\n")
|
||||
lineCnt += lipgloss.Height(heading)
|
||||
|
||||
rendered := m.messageCache[i]
|
||||
sb.WriteString(rendered)
|
||||
sb.WriteString("\n")
|
||||
lineCnt += lipgloss.Height(rendered)
|
||||
}
|
||||
|
||||
// Render a placeholder for the incoming assistant reply
|
||||
if m.state == pendingResponse && m.messages[len(m.messages)-1].Role != api.MessageRoleAssistant {
|
||||
heading := m.renderMessageHeading(-1, &api.Message{
|
||||
Role: api.MessageRoleAssistant,
|
||||
})
|
||||
sb.WriteString(heading)
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(messageStyle.Width(0).Render(m.replyCursor.View()))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m *Model) headerView() string {
|
||||
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||
var title string
|
||||
if m.conversation != nil && m.conversation.Title != "" {
|
||||
title = m.conversation.Title
|
||||
} else {
|
||||
title = "Untitled"
|
||||
}
|
||||
title = tuiutil.TruncateToCellWidth(title, m.Width-styles.Header.GetHorizontalPadding(), "...")
|
||||
header := titleStyle.Render(title)
|
||||
return styles.Header.Width(m.Width).Render(header)
|
||||
}
|
||||
|
||||
func (m *Model) footerView() string {
|
||||
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
|
||||
segmentSeparator := "|"
|
||||
|
||||
savingStyle := segmentStyle.Copy().Bold(true)
|
||||
saving := ""
|
||||
if m.persistence {
|
||||
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
|
||||
} else {
|
||||
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
|
||||
}
|
||||
|
||||
var status string
|
||||
switch m.state {
|
||||
case pendingResponse:
|
||||
status = "Press ctrl+c to cancel" + m.spinner.View()
|
||||
default:
|
||||
status = "Press ctrl+s to send"
|
||||
}
|
||||
|
||||
leftSegments := []string{
|
||||
saving,
|
||||
segmentStyle.Render(status),
|
||||
}
|
||||
rightSegments := []string{}
|
||||
|
||||
if m.elapsed > 0 && m.tokenCount > 0 {
|
||||
throughput := fmt.Sprintf("%.0f t/sec", float64(m.tokenCount)/m.elapsed.Seconds())
|
||||
rightSegments = append(rightSegments, segmentStyle.Render(throughput))
|
||||
}
|
||||
|
||||
model := fmt.Sprintf("Model: %s", *m.Shared.Ctx.Config.Defaults.Model)
|
||||
rightSegments = append(rightSegments, segmentStyle.Render(model))
|
||||
|
||||
left := strings.Join(leftSegments, segmentSeparator)
|
||||
right := strings.Join(rightSegments, segmentSeparator)
|
||||
|
||||
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
|
||||
remaining := m.Width - totalWidth
|
||||
|
||||
var padding string
|
||||
if remaining > 0 {
|
||||
padding = strings.Repeat(" ", remaining)
|
||||
}
|
||||
|
||||
footer := left + padding + right
|
||||
if remaining < 0 {
|
||||
footer = tuiutil.TruncateToCellWidth(footer, m.Width, "...")
|
||||
}
|
||||
return footerStyle.Width(m.Width).Render(footer)
|
||||
}
|
|
@ -0,0 +1,342 @@
|
|||
package conversations
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/api"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/bubbles"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/shared"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui/styles"
|
||||
tuiutil "git.mlow.ca/mlow/lmcli/pkg/tui/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
type loadedConversation struct {
|
||||
conv api.Conversation
|
||||
lastReply api.Message
|
||||
}
|
||||
|
||||
type (
|
||||
// sent when conversation list is loaded
|
||||
msgConversationsLoaded ([]loadedConversation)
|
||||
// sent when a conversation is selected
|
||||
msgConversationSelected api.Conversation
|
||||
// sent when a conversation is deleted
|
||||
msgConversationDeleted struct{}
|
||||
)
|
||||
|
||||
// Prompt payloads
|
||||
type (
|
||||
deleteConversationPayload api.Conversation
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
shared.Shared
|
||||
shared.Sections
|
||||
|
||||
conversations []loadedConversation
|
||||
cursor int // index of the currently selected conversation
|
||||
itemOffsets []int // keeps track of the viewport y offset of each rendered item
|
||||
|
||||
content viewport.Model
|
||||
|
||||
confirmPrompt bubbles.ConfirmPrompt
|
||||
}
|
||||
|
||||
func Conversations(shared shared.Shared) Model {
|
||||
m := Model{
|
||||
Shared: shared,
|
||||
content: viewport.New(0, 0),
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Model) HandleInput(msg tea.KeyMsg) (bool, tea.Cmd) {
|
||||
if m.confirmPrompt.Focused() {
|
||||
var cmd tea.Cmd
|
||||
m.confirmPrompt, cmd = m.confirmPrompt.Update(msg)
|
||||
if cmd != nil {
|
||||
return true, cmd
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "enter":
|
||||
if len(m.conversations) > 0 && m.cursor < len(m.conversations) {
|
||||
return true, func() tea.Msg {
|
||||
return msgConversationSelected(m.conversations[m.cursor].conv)
|
||||
}
|
||||
}
|
||||
case "j", "down":
|
||||
if m.cursor < len(m.conversations)-1 {
|
||||
m.cursor++
|
||||
if m.cursor == len(m.conversations)-1 {
|
||||
// if last conversation, simply scroll to the bottom
|
||||
m.content.GotoBottom()
|
||||
} else {
|
||||
// this hack positions the *next* conversatoin slightly
|
||||
// *off* the screen, ensuring the entire m.cursor is shown,
|
||||
// even if its height may not be constant due to wrapping.
|
||||
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor+1], -1)
|
||||
}
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
} else {
|
||||
m.cursor = len(m.conversations) - 1
|
||||
m.content.GotoBottom()
|
||||
}
|
||||
return true, nil
|
||||
case "k", "up":
|
||||
if m.cursor > 0 {
|
||||
m.cursor--
|
||||
if m.cursor == 0 {
|
||||
m.content.GotoTop()
|
||||
} else {
|
||||
tuiutil.ScrollIntoView(&m.content, m.itemOffsets[m.cursor], 1)
|
||||
}
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
} else {
|
||||
m.cursor = 0
|
||||
m.content.GotoTop()
|
||||
}
|
||||
return true, nil
|
||||
case "n":
|
||||
// new conversation
|
||||
case "d":
|
||||
if !m.confirmPrompt.Focused() && len(m.conversations) > 0 && m.cursor < len(m.conversations) {
|
||||
title := m.conversations[m.cursor].conv.Title
|
||||
if title == "" {
|
||||
title = "(untitled)"
|
||||
}
|
||||
m.confirmPrompt = bubbles.NewConfirmPrompt(
|
||||
fmt.Sprintf("Delete '%s'?", title),
|
||||
deleteConversationPayload(m.conversations[m.cursor].conv),
|
||||
)
|
||||
m.confirmPrompt.Style = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("3"))
|
||||
return true, nil
|
||||
}
|
||||
case "c":
|
||||
// copy/clone conversation
|
||||
case "r":
|
||||
// show prompt to rename conversation
|
||||
case "shift+r":
|
||||
// show prompt to generate name for conversation
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m Model) Init() tea.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) HandleResize(width, height int) {
|
||||
m.Width, m.Height = width, height
|
||||
m.content.Width = width
|
||||
}
|
||||
|
||||
func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
switch msg := msg.(type) {
|
||||
case shared.MsgViewEnter:
|
||||
cmds = append(cmds, m.loadConversations())
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
case tea.WindowSizeMsg:
|
||||
m.HandleResize(msg.Width, msg.Height)
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
case msgConversationsLoaded:
|
||||
m.conversations = msg
|
||||
m.cursor = max(0, min(len(m.conversations), m.cursor))
|
||||
m.content.SetContent(m.renderConversationList())
|
||||
case msgConversationSelected:
|
||||
m.Values.ConvShortname = msg.ShortName.String
|
||||
cmds = append(cmds, func() tea.Msg {
|
||||
return shared.MsgViewChange(shared.StateChat)
|
||||
})
|
||||
case bubbles.MsgConfirmPromptAnswered:
|
||||
m.confirmPrompt.Blur()
|
||||
if msg.Value {
|
||||
switch payload := msg.Payload.(type) {
|
||||
case deleteConversationPayload:
|
||||
cmds = append(cmds, m.deleteConversation(api.Conversation(payload)))
|
||||
}
|
||||
}
|
||||
case msgConversationDeleted:
|
||||
cmds = append(cmds, m.loadConversations())
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.content, cmd = m.content.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
if m.Width > 0 {
|
||||
wrap := lipgloss.NewStyle().Width(m.Width)
|
||||
m.Header = m.headerView()
|
||||
m.Footer = "" // TODO: "Press ? for help"
|
||||
if m.confirmPrompt.Focused() {
|
||||
m.Footer = wrap.Render(m.confirmPrompt.View())
|
||||
}
|
||||
m.Error = tuiutil.ErrorBanner(m.Err, m.Width)
|
||||
fixedHeight := tuiutil.Height(m.Header) + tuiutil.Height(m.Error) + tuiutil.Height(m.Footer)
|
||||
m.content.Height = m.Height - fixedHeight
|
||||
m.Content = m.content.View()
|
||||
}
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *Model) loadConversations() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, err := m.Ctx.Store.LatestConversationMessages()
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not load conversations: %v", err))
|
||||
}
|
||||
|
||||
loaded := make([]loadedConversation, len(messages))
|
||||
for i, m := range messages {
|
||||
loaded[i].lastReply = m
|
||||
loaded[i].conv = *m.Conversation
|
||||
}
|
||||
|
||||
return msgConversationsLoaded(loaded)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) deleteConversation(conv api.Conversation) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.Ctx.Store.DeleteConversation(&conv)
|
||||
if err != nil {
|
||||
return shared.MsgError(fmt.Errorf("Could not delete conversation: %v", err))
|
||||
}
|
||||
return msgConversationDeleted{}
|
||||
}
|
||||
}
|
||||
|
||||
func (m Model) View() string {
|
||||
if m.Width == 0 {
|
||||
return ""
|
||||
}
|
||||
sections := make([]string, 0, 6)
|
||||
|
||||
if m.Header != "" {
|
||||
sections = append(sections, m.Header)
|
||||
}
|
||||
|
||||
sections = append(sections, m.Content)
|
||||
if m.Error != "" {
|
||||
sections = append(sections, m.Error)
|
||||
}
|
||||
|
||||
if m.Footer != "" {
|
||||
sections = append(sections, m.Footer)
|
||||
}
|
||||
|
||||
return lipgloss.JoinVertical(lipgloss.Left, sections...)
|
||||
}
|
||||
|
||||
func (m *Model) headerView() string {
|
||||
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||
header := titleStyle.Render("Conversations")
|
||||
return styles.Header.Width(m.Width).Render(header)
|
||||
}
|
||||
|
||||
func (m *Model) renderConversationList() string {
|
||||
now := time.Now()
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
dayOfWeek := int(now.Weekday())
|
||||
categories := []struct {
|
||||
name string
|
||||
cutoff time.Duration
|
||||
}{
|
||||
{"Today", now.Sub(midnight)},
|
||||
{"Yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||
{"This week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||
{"Last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||
{"This month", now.Sub(monthStart)},
|
||||
{"Last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||
{"2 Months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||
{"3 Months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||
{"4 Months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||
{"5 Months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||
{"6 Months ago", now.Sub(monthStart.AddDate(0, -6, 0))},
|
||||
{"Older", now.Sub(time.Time{})},
|
||||
}
|
||||
|
||||
categoryStyle := lipgloss.NewStyle().
|
||||
MarginBottom(1).
|
||||
Foreground(lipgloss.Color("12")).
|
||||
PaddingLeft(1).
|
||||
Bold(true)
|
||||
|
||||
itemStyle := lipgloss.NewStyle().
|
||||
MarginBottom(1)
|
||||
|
||||
ageStyle := lipgloss.NewStyle().Faint(true).SetString()
|
||||
titleStyle := lipgloss.NewStyle().Bold(true)
|
||||
untitledStyle := lipgloss.NewStyle().Faint(true).Italic(true)
|
||||
selectedStyle := lipgloss.NewStyle().Faint(true).Foreground(lipgloss.Color("6"))
|
||||
|
||||
var (
|
||||
currentOffset int
|
||||
currentCategory string
|
||||
sb strings.Builder
|
||||
)
|
||||
|
||||
m.itemOffsets = make([]int, len(m.conversations))
|
||||
sb.WriteRune('\n')
|
||||
currentOffset += 1
|
||||
|
||||
for i, c := range m.conversations {
|
||||
lastReplyAge := now.Sub(c.lastReply.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, g := range categories {
|
||||
if lastReplyAge < g.cutoff {
|
||||
category = g.name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// print the category
|
||||
if category != currentCategory {
|
||||
currentCategory = category
|
||||
heading := categoryStyle.Render(currentCategory)
|
||||
sb.WriteString(heading)
|
||||
currentOffset += tuiutil.Height(heading)
|
||||
sb.WriteRune('\n')
|
||||
}
|
||||
|
||||
tStyle := titleStyle.Copy()
|
||||
if c.conv.Title == "" {
|
||||
tStyle = tStyle.Inherit(untitledStyle).SetString("(untitled)")
|
||||
}
|
||||
if i == m.cursor {
|
||||
tStyle = tStyle.Inherit(selectedStyle)
|
||||
}
|
||||
|
||||
title := tStyle.Width(m.Width - 3).PaddingLeft(2).Render(c.conv.Title)
|
||||
if i == m.cursor {
|
||||
title = ">" + title[1:]
|
||||
}
|
||||
|
||||
m.itemOffsets[i] = currentOffset
|
||||
item := itemStyle.Render(fmt.Sprintf(
|
||||
"%s\n %s",
|
||||
title,
|
||||
ageStyle.Render(util.HumanTimeElapsedSince(lastReplyAge)),
|
||||
))
|
||||
sb.WriteString(item)
|
||||
currentOffset += tuiutil.Height(item)
|
||||
if i < len(m.conversations)-1 {
|
||||
sb.WriteRune('\n')
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
package tty
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/chroma/v2"
|
||||
"github.com/alecthomas/chroma/v2/formatters"
|
||||
"github.com/alecthomas/chroma/v2/lexers"
|
||||
"github.com/alecthomas/chroma/v2/styles"
|
||||
)
|
||||
|
||||
type ChromaHighlighter struct {
|
||||
lexer chroma.Lexer
|
||||
formatter chroma.Formatter
|
||||
style *chroma.Style
|
||||
}
|
||||
|
||||
func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter {
|
||||
l := lexers.Get(lang)
|
||||
if l == nil {
|
||||
l = lexers.Fallback
|
||||
}
|
||||
l = chroma.Coalesce(l)
|
||||
|
||||
f := formatters.Get(format)
|
||||
if f == nil {
|
||||
f = formatters.Fallback
|
||||
}
|
||||
|
||||
s := styles.Get(style)
|
||||
if s == nil {
|
||||
s = styles.Fallback
|
||||
}
|
||||
|
||||
return &ChromaHighlighter{
|
||||
lexer: l,
|
||||
formatter: f,
|
||||
style: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.formatter.Format(w, s.style, it)
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.Grow(len(text) * 2)
|
||||
s.formatter.Format(&sb, s.style, it)
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) HighlightLang(w io.Writer, text string, lang string) (error) {
|
||||
l := lexers.Get(lang)
|
||||
if l == nil {
|
||||
l = lexers.Fallback
|
||||
}
|
||||
l = chroma.Coalesce(l)
|
||||
old := s.lexer
|
||||
s.lexer = l
|
||||
err := s.Highlight(w, text)
|
||||
s.lexer = old
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) HighlightLangS(text string, lang string) (string, error) {
|
||||
l := lexers.Get(lang)
|
||||
if l == nil {
|
||||
l = lexers.Fallback
|
||||
}
|
||||
l = chroma.Coalesce(l)
|
||||
old := s.lexer
|
||||
s.lexer = l
|
||||
highlighted, err := s.HighlightS(text)
|
||||
s.lexer = old
|
||||
return highlighted, err
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package cli
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -17,11 +17,11 @@ import (
|
|||
// contents of the file exactly match the value of placeholder (no edits to the
|
||||
// file were made), then an empty string is returned. Otherwise, the contents
|
||||
// are returned. Example patten: message.*.md
|
||||
func InputFromEditor(placeholder string, pattern string) (string, error) {
|
||||
func InputFromEditor(placeholder string, pattern string, content string) (string, error) {
|
||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||
defer os.Remove(msgFile.Name())
|
||||
|
||||
os.WriteFile(msgFile.Name(), []byte(placeholder), os.ModeAppend)
|
||||
os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
||||
|
||||
editor := os.Getenv("EDITOR")
|
||||
if editor == "" {
|
||||
|
@ -38,7 +38,7 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
|
|||
}
|
||||
|
||||
bytes, _ := os.ReadFile(msgFile.Name())
|
||||
content := string(bytes)
|
||||
content = string(bytes)
|
||||
|
||||
if placeholder != "" {
|
||||
if content == placeholder {
|
||||
|
@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
|
|||
|
||||
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
||||
// of the given duration.
|
||||
func humanTimeElapsedSince(d time.Duration) string {
|
||||
func HumanTimeElapsedSince(d time.Duration) string {
|
||||
seconds := d.Seconds()
|
||||
minutes := seconds / 60
|
||||
hours := minutes / 60
|
||||
|
@ -137,8 +137,8 @@ func SetStructDefaults(data interface{}) bool {
|
|||
}
|
||||
|
||||
// Get the "default" struct tag
|
||||
defaultTag := v.Type().Field(i).Tag.Get("default")
|
||||
if defaultTag == "" {
|
||||
defaultTag, ok := v.Type().Field(i).Tag.Lookup("default")
|
||||
if (!ok) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -147,10 +147,18 @@ func SetStructDefaults(data interface{}) bool {
|
|||
case reflect.String:
|
||||
defaultValue := defaultTag
|
||||
field.Set(reflect.ValueOf(&defaultValue))
|
||||
case reflect.Uint, reflect.Uint32, reflect.Uint64:
|
||||
intValue, _ := strconv.ParseUint(defaultTag, 10, e.Bits())
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetUint(intValue)
|
||||
case reflect.Int, reflect.Int32, reflect.Int64:
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, e.Bits())
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetInt(intValue)
|
||||
case reflect.Float32, reflect.Float64:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, e.Bits())
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Bool:
|
||||
boolValue := defaultTag == "true"
|
||||
field.Set(reflect.ValueOf(&boolValue))
|
||||
|
@ -160,10 +168,8 @@ func SetStructDefaults(data interface{}) bool {
|
|||
return changed
|
||||
}
|
||||
|
||||
// FileContents returns the string contents of the given file.
|
||||
// TODO: we should support retrieving the content (or an approximation of)
|
||||
// non-text documents, e.g. PDFs.
|
||||
func FileContents(file string) (string, error) {
|
||||
// ReadFileContents returns the string contents of the given file.
|
||||
func ReadFileContents(file string) (string, error) {
|
||||
path := filepath.Clean(file)
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
Loading…
Reference in New Issue