Compare commits
77 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3185b2d7d6 | |||
| 6c64f21d9a | |||
| 6f737ad19c | |||
| a8ffdc156a | |||
| 7a974d9764 | |||
| adb61ffa59 | |||
| 1c7ad75fd5 | |||
| 613aa1a552 | |||
| 71833b89cd | |||
| 2ad93394b1 | |||
| f49b772960 | |||
| 29d8138dc0 | |||
| 3756f6d9e4 | |||
| 41916eb7b3 | |||
| 3892e68251 | |||
| 8697284064 | |||
| 383d34f311 | |||
| ac0e380244 | |||
| c3a3cb0181 | |||
| 612ea90417 | |||
| 94508b1dbf | |||
| 7e002e5214 | |||
| 48e4dea3cf | |||
| 0ab552303d | |||
| 6ce42a77f9 | |||
| 2cb1a0005d | |||
| ea78edf039 | |||
| 793aaab50e | |||
| 5afc9667c7 | |||
| dfafc573e5 | |||
| 97f81a0cbb | |||
| eca120cde6 | |||
| 12d4e495d4 | |||
| d8c8262890 | |||
| 758f74aba5 | |||
| 1570c23d63 | |||
| 46149e0b67 | |||
| c2c61e2aaa | |||
| 5e880d3b31 | |||
| 62f07dd240 | |||
| ec1f326c2a | |||
| db116660a5 | |||
| 32eab7aa35 | |||
| 91d3c9c2e1 | |||
| 8bdb155bf7 | |||
| 045146bb5c | |||
| 2c7bdd8ebf | |||
| 7d56726c78 | |||
| f2c7d2bdd0 | |||
| 0a27b9a8d3 | |||
| 2611663168 | |||
| 120e61e88b | |||
| fa966d30db | |||
| 51ce74ad3a | |||
| b93ee94233 | |||
| db788760a3 | |||
| 242ed886ec | |||
| 02a23b9035 | |||
| b3913d0027 | |||
| 1184f9aaae | |||
| a25d0d95e8 | |||
| becaa5c7c0 | |||
| 239ded18f3 | |||
| 59e78669c8 | |||
| 1966ec881b | |||
| f6ded3e20e | |||
| 1e8ff60c54 | |||
| af2fccd4ee | |||
| f206334e72 | |||
| 5615051637 | |||
| c46500de4e | |||
| d5dde10dbf | |||
| d32e9421fe | |||
| e29dbaf2a3 | |||
| c64bc370f4 | |||
| 4f37ed046b | |||
| ed6ee9bea9 |
38
README.md
38
README.md
@@ -4,17 +4,39 @@
|
||||
|
||||
Current features:
|
||||
- Perform one-shot prompts with `lmcli prompt <message>`
|
||||
- Manage persistent conversations with the `new`, `reply`, `view`, and `rm`
|
||||
sub-commands.
|
||||
- Manage persistent conversations with the `new`, `reply`, `view`, `rm`,
|
||||
`edit`, `retry`, `continue` sub-commands.
|
||||
- Syntax highlighted output
|
||||
|
||||
Planned features:
|
||||
- Ask questions about content received on stdin
|
||||
- "functions" to allow reading (and possibly writing) to files within the
|
||||
current working directory
|
||||
- Tool calling, see the [Tools](#tools) section.
|
||||
|
||||
Maybe features:
|
||||
- Natural language image generation, iterative editing
|
||||
- Chat-like interface (`lmcli chat`) for rapid back-and-forth conversations
|
||||
- Support for additional models/APIs besides just OpenAI
|
||||
|
||||
## Tools
|
||||
Tools must be explicitly enabled by adding the tool's name to the
|
||||
`openai.enabledTools` array in `config.yaml`.
|
||||
|
||||
Note: all filesystem related tools operate relative to the current directory
|
||||
only. They do not accept absolute paths, and efforts are made to ensure they
|
||||
cannot escape above the working directory). **Close attention must be paid to
|
||||
where you are running `lmcli`, as the model could at any time decide to use one
|
||||
of these tools to discover and read potentially sensitive information from your
|
||||
filesystem.**
|
||||
|
||||
It's best to only have tools enabled in `config.yaml` when you intend to be
|
||||
using them, since their descriptions (see `pkg/cli/functions.go`) count towards
|
||||
context usage.
|
||||
|
||||
Available tools:
|
||||
|
||||
- `read_dir` - Read the contents of a directory.
|
||||
- `read_file` - Read the contents of a file.
|
||||
- `write_file` - Write contents to a file.
|
||||
- `file_insert_lines` - Insert lines at a position within a file. Tricky for
|
||||
the model to use, but can potentially save tokens.
|
||||
- `file_replace_lines` - Remove or replace a range of lines within a file. Even
|
||||
trickier for the model to use.
|
||||
|
||||
## Install
|
||||
|
||||
|
||||
20
go.mod
20
go.mod
@@ -4,8 +4,10 @@ go 1.21
|
||||
|
||||
require (
|
||||
github.com/alecthomas/chroma/v2 v2.11.1
|
||||
github.com/charmbracelet/bubbles v0.18.0
|
||||
github.com/charmbracelet/bubbletea v0.25.0
|
||||
github.com/charmbracelet/lipgloss v0.10.0
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible
|
||||
github.com/gookit/color v1.5.4
|
||||
github.com/sashabaranov/go-openai v1.17.7
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/sqids/sqids-go v0.4.1
|
||||
@@ -14,15 +16,29 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.18 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.18 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.15.2 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.14.0 // indirect
|
||||
golang.org/x/term v0.6.0 // indirect
|
||||
golang.org/x/text v0.3.8 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
|
||||
gopkg.in/yaml.v2 v2.2.2 // indirect
|
||||
)
|
||||
|
||||
52
go.sum
52
go.sum
@@ -4,16 +4,24 @@ github.com/alecthomas/chroma/v2 v2.11.1 h1:m9uUtgcdAwgfFNxuqj7AIG75jD2YmL61BBIJW
|
||||
github.com/alecthomas/chroma/v2 v2.11.1/go.mod h1:4TQu7gdfuPjSh76j78ietmqh9LiurGF0EpseFXdKMBw=
|
||||
github.com/alecthomas/repr v0.2.0 h1:HAzS41CIzNW5syS8Mf9UwXhNH1J9aix/BvDRf1Ml2Yk=
|
||||
github.com/alecthomas/repr v0.2.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/charmbracelet/bubbles v0.18.0 h1:PYv1A036luoBGroX6VWjQIE9Syf2Wby2oOl/39KLfy0=
|
||||
github.com/charmbracelet/bubbles v0.18.0/go.mod h1:08qhZhtIwzgrtBjAcJnij1t1H0ZRjwHyGsy6AL11PSw=
|
||||
github.com/charmbracelet/bubbletea v0.25.0 h1:bAfwk7jRz7FKFl9RzlIULPkStffg5k6pNt5dywy4TcM=
|
||||
github.com/charmbracelet/bubbletea v0.25.0/go.mod h1:EN3QDR1T5ZdWmdfDzYcqOCAps45+QIJbLOBxmVNWNNg=
|
||||
github.com/charmbracelet/lipgloss v0.10.0 h1:KWeXFSexGcfahHX+54URiZGkBFazf70JNMtwg/AFW3s=
|
||||
github.com/charmbracelet/lipgloss v0.10.0/go.mod h1:Wig9DSfvANsxqkRsqj6x87irdy123SR4dOXlKa91ciE=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81 h1:q2hJAaP1k2wIvVRd/hEHD7lacgqrCPS+k8g1MndzfWY=
|
||||
github.com/containerd/console v1.0.4-0.20230313162750-1ae8d489ac81/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o=
|
||||
github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0=
|
||||
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
|
||||
github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w=
|
||||
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
|
||||
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
@@ -26,11 +34,30 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp98=
|
||||
github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
|
||||
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
|
||||
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
|
||||
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b h1:1XF24mVaiu7u+CFywTdcDo2ie1pzzhwjt6RHqzpMU34=
|
||||
github.com/muesli/ansi v0.0.0-20211018074035-2e021307bc4b/go.mod h1:fQuZ0gauxyBcmsdE3ZT4NasjaRdxmbCS0jRHsrWu3Ho=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
|
||||
github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
|
||||
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
@@ -42,18 +69,21 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw=
|
||||
github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8=
|
||||
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
|
||||
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/sqlite v1.5.4 h1:IqXwXi8M/ZlPzH/947tn5uik3aYQslP9BVveoax0nV0=
|
||||
gorm.io/driver/sqlite v1.5.4/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
|
||||
|
||||
17
main.go
17
main.go
@@ -1,15 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/cli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := cli.Execute(); err != nil {
|
||||
fmt.Fprint(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
ctx, err := lmcli.NewContext()
|
||||
if err != nil {
|
||||
lmcli.Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
root := cmd.RootCmd(ctx)
|
||||
if err := root.Execute(); err != nil {
|
||||
lmcli.Fatal("%v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
var config *Config
|
||||
var store *Store
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
|
||||
config, err = NewConfig()
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
store, err = NewStore()
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
}
|
||||
550
pkg/cli/cmd.go
550
pkg/cli/cmd.go
@@ -1,550 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
maxTokens int
|
||||
model string
|
||||
systemPrompt string
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
func init() {
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().IntVar(&maxTokens, "length", *config.OpenAI.DefaultMaxLength, "Max response length in tokens")
|
||||
cmd.Flags().StringVar(&model, "model", *config.OpenAI.DefaultModel, "The language model to use")
|
||||
cmd.Flags().StringVar(&systemPrompt, "system-prompt", *config.ModelDefaults.SystemPrompt, "The system prompt to use.")
|
||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file whose contents are used as the system prompt.")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
rootCmd.AddCommand(
|
||||
continueCmd,
|
||||
lsCmd,
|
||||
newCmd,
|
||||
promptCmd,
|
||||
replyCmd,
|
||||
retryCmd,
|
||||
rmCmd,
|
||||
viewCmd,
|
||||
)
|
||||
}
|
||||
|
||||
func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
func SystemPrompt() string {
|
||||
if systemPromptFile != "" {
|
||||
content, err := FileContents(systemPromptFile)
|
||||
if err != nil {
|
||||
Fatal("Could not read file contents at %s: %v", systemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
|
||||
// LLMRequest prompts the LLM with the given Message, writes the (partial)
|
||||
// response to stdout, and returns the (partial) response or any errors.
|
||||
func LLMRequest(messages []Message) (string, error) {
|
||||
// receiver receives the reponse from LLM
|
||||
receiver := make(chan string)
|
||||
defer close(receiver)
|
||||
|
||||
// start HandleDelayedContent goroutine to print received data to stdout
|
||||
go HandleDelayedContent(receiver)
|
||||
|
||||
response, err := CreateChatCompletionStream(model, messages, maxTokens, receiver)
|
||||
if response != "" {
|
||||
if err != nil {
|
||||
Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
// InputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func InputFromArgsOrEditor(args []string, placeholder string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = InputFromEditor(placeholder, "message.*.md")
|
||||
if err != nil {
|
||||
Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Trim(strings.Join(args, " "), " \t\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Use: "lmcli",
|
||||
Short: "Interact with Large Language Models",
|
||||
Long: `lmcli is a CLI tool to interact with Large Language Models.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
// execute `lm ls` by default
|
||||
},
|
||||
}
|
||||
|
||||
var lsCmd = &cobra.Command{
|
||||
Use: "ls",
|
||||
Short: "List existing conversations",
|
||||
Long: `List all existing conversations in descending order of recent activity.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
conversations, err := store.Conversations()
|
||||
if err != nil {
|
||||
fmt.Println("Could not fetch conversations.")
|
||||
return
|
||||
}
|
||||
|
||||
// Example output
|
||||
// $ lmcli ls
|
||||
// last hour:
|
||||
// 98sg - 12 minutes ago - Project discussion
|
||||
// last day:
|
||||
// tj3l - 10 hours ago - Deep learning concepts
|
||||
// last week:
|
||||
// bwfm - 2 days ago - Machine learning study
|
||||
// 8n3h - 3 days ago - Weekend plans
|
||||
// f3n7 - 6 days ago - CLI development
|
||||
// last month:
|
||||
// 5hn2 - 8 days ago - Book club discussion
|
||||
// b7ze - 20 days ago - Gardening tips and tricks
|
||||
// last 6 months:
|
||||
// 3jn2 - 30 days ago - Web development best practices
|
||||
// 43jk - 2 months ago - Longboard maintenance
|
||||
// g8d9 - 3 months ago - History book club
|
||||
// 4lk3 - 4 months ago - Local events and meetups
|
||||
// 43jn - 6 months ago - Mobile photography techniques
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
categories := []string{
|
||||
"recent",
|
||||
"last hour",
|
||||
"last 6 hours",
|
||||
"last day",
|
||||
"last week",
|
||||
"last month",
|
||||
"last 6 months",
|
||||
"older",
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
switch {
|
||||
case messageAge <= 10*time.Minute:
|
||||
category = "recent"
|
||||
case messageAge <= time.Hour:
|
||||
category = "last hour"
|
||||
case messageAge <= 6*time.Hour:
|
||||
category = "last 6 hours"
|
||||
case messageAge <= 24*time.Hour:
|
||||
category = "last day"
|
||||
case messageAge <= 7*24*time.Hour:
|
||||
category = "last week"
|
||||
case messageAge <= 30*24*time.Hour:
|
||||
category = "last month"
|
||||
case messageAge <= 6*30*24*time.Hour: // Approximate as 6 months
|
||||
category = "last 6 months"
|
||||
default:
|
||||
category = "older"
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
conversation.ShortName.String,
|
||||
humanTimeElapsedSince(messageAge),
|
||||
conversation.Title,
|
||||
)
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
for _, category := range categories {
|
||||
conversations, ok := categorized[category]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
slices.SortFunc(conversations, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
fmt.Printf("%s:\n", category)
|
||||
for _, conv := range conversations {
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var rmCmd = &cobra.Command{
|
||||
Use: "rm <conversation>",
|
||||
Short: "Remove a conversation",
|
||||
Long: `Removes a conversation by its short name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
Fatal("Could not search for conversation: %v\n", err)
|
||||
}
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
err = store.DeleteConversation(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not delete conversation: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var viewCmd = &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
RenderConversation(messages, false)
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var replyCmd = &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Send a reply to a conversation",
|
||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
messageContents := InputFromArgsOrEditor(args[1:], "# How would you like to reply?\n")
|
||||
if messageContents == "" {
|
||||
Fatal("No reply was provided.\n")
|
||||
}
|
||||
|
||||
userReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
}
|
||||
|
||||
err = store.SaveMessage(&userReply)
|
||||
if err != nil {
|
||||
Warn("Could not save your reply: %v\n", err)
|
||||
}
|
||||
|
||||
messages = append(messages, userReply)
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var newCmd = &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
messageContents := InputFromArgsOrEditor(args, "# What would you like to say?\n")
|
||||
if messageContents == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
conversation := Conversation{}
|
||||
err := store.SaveConversation(&conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not save new conversation: %v\n", err)
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "system",
|
||||
OriginalContent: SystemPrompt(),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "user",
|
||||
OriginalContent: messageContents,
|
||||
},
|
||||
}
|
||||
for _, message := range messages {
|
||||
err = store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
reply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
reply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
reply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
Fatal("Could not save reply: %v\n", err)
|
||||
}
|
||||
|
||||
err = conversation.GenerateTitle()
|
||||
if err != nil {
|
||||
Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
err = store.SaveConversation(&conversation)
|
||||
if err != nil {
|
||||
Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var promptCmd = &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
message := InputFromArgsOrEditor(args, "# What would you like to say?\n")
|
||||
if message == "" {
|
||||
Fatal("No message was provided.\n")
|
||||
}
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "system",
|
||||
OriginalContent: SystemPrompt(),
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
var retryCmd = &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retries the last conversation prompt.",
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
var lastUserMessageIndex int
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
lastUserMessageIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
messages = messages[:lastUserMessageIndex+1]
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
var continueCmd = &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continues where the previous prompt left off.",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
shortName := args[0]
|
||||
conversation, err := store.ConversationByShortName(shortName)
|
||||
if conversation.ID == 0 {
|
||||
Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
|
||||
messages, err := store.Messages(conversation)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation: %s\n", conversation.Title)
|
||||
}
|
||||
|
||||
RenderConversation(messages, true)
|
||||
assistantReply := Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: "assistant",
|
||||
}
|
||||
assistantReply.RenderTTY()
|
||||
|
||||
response, err := LLMRequest(messages)
|
||||
if err != nil {
|
||||
Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
|
||||
assistantReply.OriginalContent = response
|
||||
|
||||
err = store.SaveMessage(&assistantReply)
|
||||
if err != nil {
|
||||
Fatal("Could not save assistant reply: %v\n", err)
|
||||
}
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/go-yaml/yaml"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ModelDefaults *struct {
|
||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||
} `yaml:"modelDefaults"`
|
||||
OpenAI *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
DefaultModel *string `yaml:"defaultModel" default:"gpt-4"`
|
||||
DefaultMaxLength *int `yaml:"defaultMaxLength" default:"256"`
|
||||
} `yaml:"openai"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
}
|
||||
|
||||
func ConfigDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func NewConfig() (*Config, error) {
|
||||
configFile := filepath.Join(ConfigDir(), "config.yaml")
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
configBytes, err := os.ReadFile(configFile)
|
||||
if os.IsNotExist(err) {
|
||||
shouldWriteDefaults = true
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||
} else {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
}
|
||||
|
||||
shouldWriteDefaults = SetStructDefaults(c)
|
||||
if shouldWriteDefaults {
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
}
|
||||
bytes, _ := yaml.Marshal(c)
|
||||
_, err = file.Write(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m *Message) FriendlyRole() string {
|
||||
var friendlyRole string
|
||||
switch m.Role {
|
||||
case "user":
|
||||
friendlyRole = "You"
|
||||
case "system":
|
||||
friendlyRole = "System"
|
||||
case "assistant":
|
||||
friendlyRole = "Assistant"
|
||||
default:
|
||||
friendlyRole = m.Role
|
||||
}
|
||||
return friendlyRole
|
||||
}
|
||||
|
||||
func (c *Conversation) GenerateTitle() error {
|
||||
const header = "Generate a consise 4-5 word title for the conversation below."
|
||||
prompt := fmt.Sprintf("%s\n\n---\n\n%s", header, c.FormatForExternalPrompting())
|
||||
|
||||
messages := []Message{
|
||||
{
|
||||
Role: "user",
|
||||
OriginalContent: prompt,
|
||||
},
|
||||
}
|
||||
|
||||
model := "gpt-3.5-turbo" // use cheap model to generate title
|
||||
response, err := CreateChatCompletion(model, messages, 25)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.Title = response
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conversation) FormatForExternalPrompting() string {
|
||||
sb := strings.Builder{}
|
||||
messages, err := store.Messages(c)
|
||||
if err != nil {
|
||||
Fatal("Could not retrieve messages for conversation %v", c)
|
||||
}
|
||||
for _, message := range messages {
|
||||
sb.WriteString(fmt.Sprintf("<%s>\n", message.FriendlyRole()))
|
||||
sb.WriteString(fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n", message.OriginalContent))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1,582 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type FunctionResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameter struct {
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type FunctionParameters struct {
|
||||
Type string `json:"type"` // "object"
|
||||
Properties map[string]FunctionParameter `json:"properties"`
|
||||
Required []string `json:"required,omitempty"` // required function parameter names
|
||||
}
|
||||
|
||||
type AvailableTool struct {
|
||||
openai.Tool
|
||||
// The tool's implementation. Returns a string, as tool call results
|
||||
// are treated as normal messages with string contents.
|
||||
Impl func(arguments map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
const (
|
||||
READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Results are returned as JSON in the following format:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
// result may be an empty array if there are no files in the directory
|
||||
"result": [
|
||||
{"name": "a_file", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
... // more files or directories
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size (in bytes) of the file.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the file is prefixed with its line number and a tabs (\t) to make
|
||||
it make it easier to see which lines to change for other modifications.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Note: only use this tool when you've been explicitly asked to create or write to a file.
|
||||
|
||||
When using this function, you do not need to share the content you intend to write with the user first.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success", // if successful, or a different message indicating failure
|
||||
}`
|
||||
|
||||
FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Be cautious with your edits. When replacing, ensure the replacement content matches the flow and indentation of surrounding content.`
|
||||
)
|
||||
|
||||
var AvailableTools = map[string]AvailableTool{
|
||||
"read_dir": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"relative_dir": {
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return ReadDir(relativeDir), nil
|
||||
},
|
||||
},
|
||||
"read_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
},
|
||||
},
|
||||
Required: []string{"path"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
return ReadFile(path), nil
|
||||
},
|
||||
},
|
||||
"write_file": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
return WriteFile(path, content), nil
|
||||
},
|
||||
},
|
||||
"file_insert_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"position": {
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "position", "content"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
return FileInsertLines(path, position, content), nil
|
||||
},
|
||||
},
|
||||
"file_replace_lines": {
|
||||
Tool: openai.Tool{Type: "function", Function: openai.FunctionDefinition{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: FunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]FunctionParameter{
|
||||
"path": {
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
},
|
||||
"start_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
},
|
||||
"end_line": {
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
"content": {
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Required: []string{"path", "start_line"},
|
||||
},
|
||||
}},
|
||||
Impl: func(args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
return FileReplaceLines(path, start_line, end_line, content), nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func resultToJson(result FunctionResult) string {
|
||||
if result.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
result.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not marshal FunctionResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes)
|
||||
}
|
||||
|
||||
// ExecuteToolCalls handles the execution of all tool_calls provided, and
|
||||
// returns their results formatted as []Message(s) with role: 'tool' and.
|
||||
func ExecuteToolCalls(toolCalls []openai.ToolCall) ([]Message, error) {
|
||||
var toolResults []Message
|
||||
for _, toolCall := range toolCalls {
|
||||
if toolCall.Type != "function" {
|
||||
// unsupported tool type
|
||||
continue
|
||||
}
|
||||
|
||||
tool, ok := AvailableTools[toolCall.Function.Name]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Function.Name)
|
||||
}
|
||||
|
||||
var functionArgs map[string]interface{}
|
||||
err := json.Unmarshal([]byte(toolCall.Function.Arguments), &functionArgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not unmarshal tool arguments. Malformed JSON? Error: %v", err)
|
||||
}
|
||||
|
||||
// TODO: ability to silence this
|
||||
fmt.Fprintf(os.Stderr, "INFO: Executing tool '%s' with args %s\n", toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := tool.Impl(functionArgs)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Function.Name, err)
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, Message{
|
||||
Role: "tool",
|
||||
OriginalContent: toolResult,
|
||||
ToolCallID: sql.NullString{String: toolCall.ID, Valid: true},
|
||||
// name is not required since the introduction of ToolCallID
|
||||
// hypothesis: by setting it, we inform the model of what a
|
||||
// function's purpose was if future requests omit the function
|
||||
// definition
|
||||
})
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func isPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func isPathWithinCWD(path string) (bool, *FunctionResult) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: "Failed to determine current working directory"}
|
||||
}
|
||||
if ok, err := isPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())}
|
||||
}
|
||||
return false, &FunctionResult{Message: fmt.Sprintf("Path '%s' is not within the current working directory", path)}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ReadDir(path string) string {
|
||||
// TODO(?): implement whitelist - list of directories which model is allowed to work in
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: dirContents})
|
||||
}
|
||||
|
||||
func ReadFile(path string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{
|
||||
Result: content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func WriteFile(path string, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
return resultToJson(FunctionResult{})
|
||||
}
|
||||
|
||||
func FileInsertLines(path string, position int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
}
|
||||
|
||||
func FileReplaceLines(path string, startLine int, endLine int, content string) string {
|
||||
ok, res := isPathWithinCWD(path)
|
||||
if !ok {
|
||||
return resultToJson(*res)
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())})
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())})
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return resultToJson(FunctionResult{Message: "start_line cannot be less than 1"})
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return resultToJson(FunctionResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())})
|
||||
}
|
||||
|
||||
return resultToJson(FunctionResult{Result: newContent})
|
||||
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func CreateChatCompletionRequest(model string, messages []Message, maxTokens int) openai.ChatCompletionRequest {
|
||||
chatCompletionMessages := []openai.ChatCompletionMessage{}
|
||||
for _, m := range messages {
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: m.Role,
|
||||
Content: m.OriginalContent,
|
||||
}
|
||||
if m.ToolCallID.Valid {
|
||||
message.ToolCallID = m.ToolCallID.String
|
||||
}
|
||||
if m.ToolCalls.Valid {
|
||||
// unmarshal directly into chatMessage.ToolCalls
|
||||
err := json.Unmarshal([]byte(m.ToolCalls.String), &message.ToolCalls)
|
||||
if err != nil {
|
||||
// TODO: handle, this shouldn't really happen since
|
||||
// we only save the successfully marshal'd data to database
|
||||
fmt.Printf("Error unmarshalling the tool_calls JSON: %v\n", err)
|
||||
}
|
||||
}
|
||||
chatCompletionMessages = append(chatCompletionMessages, message)
|
||||
}
|
||||
|
||||
var tools []openai.Tool
|
||||
for _, t := range AvailableTools {
|
||||
// TODO: support some way to limit which tools are available per-request
|
||||
tools = append(tools, t.Tool)
|
||||
}
|
||||
|
||||
return openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: chatCompletionMessages,
|
||||
MaxTokens: maxTokens,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
Tools: tools,
|
||||
ToolChoice: "auto", // TODO: allow limiting/forcing which function is called?
|
||||
}
|
||||
}
|
||||
|
||||
// CreateChatCompletion submits a Chat Completion API request and returns the
|
||||
// response. CreateChatCompletion will recursively call itself in the case of
|
||||
// tool calls, until a response is received with the final user-facing output.
|
||||
func CreateChatCompletion(model string, messages []Message, maxTokens int) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
resp, err := client.CreateChatCompletion(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
if choice.Message.Content != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(choice.Message.ToolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(choice.Message.ToolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies added
|
||||
// to the original messages
|
||||
return CreateChatCompletion(model, append(messages, toolReplies...), maxTokens)
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
|
||||
|
||||
// CreateChatCompletionStream submits a streaming Chat Completion API request
|
||||
// and both returns and streams the response to the provided output channel.
|
||||
// May return a partial response if an error occurs mid-stream.
|
||||
func CreateChatCompletionStream(model string, messages []Message, maxTokens int, output chan<- string) (string, error) {
|
||||
client := openai.NewClient(*config.OpenAI.APIKey)
|
||||
req := CreateChatCompletionRequest(model, messages, maxTokens)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(context.Background(), req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
if content.String() != "" {
|
||||
return "", fmt.Errorf("Model replied with user-facing content in addition to tool calls. Unsupported.")
|
||||
}
|
||||
|
||||
// Append the assistant's reply with its request for tool calls
|
||||
toolCallJson, _ := json.Marshal(toolCalls)
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
ToolCalls: sql.NullString{String: string(toolCallJson), Valid: true},
|
||||
})
|
||||
|
||||
toolReplies, err := ExecuteToolCalls(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
return CreateChatCompletionStream(model, append(messages, toolReplies...), maxTokens, output)
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
|
||||
133
pkg/cli/store.go
133
pkg/cli/store.go
@@ -1,133 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Conversation Conversation
|
||||
OriginalContent string
|
||||
Role string // one of: 'user', 'assistant', 'tool'
|
||||
CreatedAt time.Time
|
||||
ToolCallID sql.NullString
|
||||
ToolCalls sql.NullString // a json-encoded array of tool calls from the model
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
}
|
||||
|
||||
func DataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func NewStore() (*Store, error) {
|
||||
databaseFile := filepath.Join(DataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
|
||||
models := []any{
|
||||
&Conversation{},
|
||||
&Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &Store{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *Store) SaveConversation(conversation *Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteConversation(conversation *Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *Store) SaveMessage(message *Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *Store) Conversations() ([]Conversation, error) {
|
||||
var conversations []Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *Store) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *Store) ConversationByShortName(shortName string) (*Conversation, error) {
|
||||
var conversation Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *Store) Messages(conversation *Conversation) ([]Message, error) {
|
||||
var messages []Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *Store) LastMessage(conversation *Conversation) (*Message, error) {
|
||||
var message Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
113
pkg/cli/tty.go
113
pkg/cli/tty.go
@@ -1,113 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/alecthomas/chroma/v2/quick"
|
||||
"github.com/gookit/color"
|
||||
)
|
||||
|
||||
// ShowWaitAnimation "draws" an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// noftify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
fmt.Print("\r")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\r")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func HandleDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(messages []Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
message.RenderTTY()
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HighlightMarkdown applies syntax highlighting to the provided markdown text
|
||||
// and writes it to stdout.
|
||||
func HighlightMarkdown(markdownText string) error {
|
||||
return quick.Highlight(os.Stdout, markdownText, "md", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
}
|
||||
|
||||
func (m *Message) RenderTTY() {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = humanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
var roleStyle color.Style
|
||||
switch m.Role {
|
||||
case "system":
|
||||
roleStyle = color.Style{color.HiRed}
|
||||
case "user":
|
||||
roleStyle = color.Style{color.HiGreen}
|
||||
case "assistant":
|
||||
roleStyle = color.Style{color.HiBlue}
|
||||
default:
|
||||
roleStyle = color.Style{color.FgWhite}
|
||||
}
|
||||
roleStyle.Add(color.Bold)
|
||||
|
||||
headerColor := color.FgYellow
|
||||
separator := headerColor.Sprint("===")
|
||||
timestamp := headerColor.Sprint(messageAge)
|
||||
role := roleStyle.Sprint(m.FriendlyRole())
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.OriginalContent != "" {
|
||||
HighlightMarkdown(m.OriginalContent)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
37
pkg/cmd/chat.go
Normal file
37
pkg/cmd/chat.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ChatCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "chat [conversation]",
|
||||
Short: "Open the chat interface",
|
||||
Long: `Open the chat interface, optionally on a given conversation.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// TODO: implement jump-to-conversation logic
|
||||
shortname := ""
|
||||
if len(args) == 1 {
|
||||
shortname = args[0]
|
||||
}
|
||||
err := tui.Launch(ctx, shortname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
72
pkg/cmd/clone.go
Normal file
72
pkg/cmd/clone.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func CloneCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "clone <conversation>",
|
||||
Short: "Clone conversations",
|
||||
Long: `Clones the provided conversation.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
toClone, err := cmdutil.LookupConversationE(ctx, shortName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
messagesToCopy, err := ctx.Store.Messages(toClone)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", toClone.ShortName.String)
|
||||
}
|
||||
|
||||
clone := &model.Conversation{
|
||||
Title: toClone.Title + " - Clone",
|
||||
}
|
||||
if err := ctx.Store.SaveConversation(clone); err != nil {
|
||||
return fmt.Errorf("Cloud not create clone: %s", err)
|
||||
}
|
||||
|
||||
var errors []error
|
||||
messageCnt := 0
|
||||
for _, message := range messagesToCopy {
|
||||
newMessage := message
|
||||
newMessage.ConversationID = clone.ID
|
||||
newMessage.ID = 0
|
||||
if err := ctx.Store.SaveMessage(&newMessage); err != nil {
|
||||
errors = append(errors, err)
|
||||
} else {
|
||||
messageCnt++
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Messages failed to be cloned: %v", errors)
|
||||
}
|
||||
|
||||
fmt.Printf("Cloned %d messages to: %s\n", messageCnt, clone.Title)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
96
pkg/cmd/cmd.go
Normal file
96
pkg/cmd/cmd.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
systemPromptFile string
|
||||
)
|
||||
|
||||
func RootCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
var root = &cobra.Command{
|
||||
Use: "lmcli <command> [flags]",
|
||||
Long: `lmcli - Large Language Model CLI`,
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Usage()
|
||||
},
|
||||
}
|
||||
|
||||
chatCmd := ChatCmd(ctx)
|
||||
continueCmd := ContinueCmd(ctx)
|
||||
cloneCmd := CloneCmd(ctx)
|
||||
editCmd := EditCmd(ctx)
|
||||
listCmd := ListCmd(ctx)
|
||||
newCmd := NewCmd(ctx)
|
||||
promptCmd := PromptCmd(ctx)
|
||||
renameCmd := RenameCmd(ctx)
|
||||
replyCmd := ReplyCmd(ctx)
|
||||
retryCmd := RetryCmd(ctx)
|
||||
rmCmd := RemoveCmd(ctx)
|
||||
viewCmd := ViewCmd(ctx)
|
||||
|
||||
inputCmds := []*cobra.Command{newCmd, promptCmd, replyCmd, retryCmd, continueCmd, editCmd}
|
||||
for _, cmd := range inputCmds {
|
||||
cmd.Flags().StringVar(ctx.Config.Defaults.Model, "model", *ctx.Config.Defaults.Model, "Which model to use")
|
||||
cmd.RegisterFlagCompletionFunc("model", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) {
|
||||
return ctx.GetModels(), cobra.ShellCompDirectiveDefault
|
||||
})
|
||||
cmd.Flags().IntVar(ctx.Config.Defaults.MaxTokens, "length", *ctx.Config.Defaults.MaxTokens, "Maximum response tokens")
|
||||
cmd.Flags().StringVar(ctx.Config.Defaults.SystemPrompt, "system-prompt", *ctx.Config.Defaults.SystemPrompt, "System prompt")
|
||||
cmd.Flags().StringVar(&systemPromptFile, "system-prompt-file", "", "A path to a file containing the system prompt")
|
||||
cmd.MarkFlagsMutuallyExclusive("system-prompt", "system-prompt-file")
|
||||
}
|
||||
|
||||
root.AddCommand(
|
||||
chatCmd,
|
||||
cloneCmd,
|
||||
continueCmd,
|
||||
editCmd,
|
||||
listCmd,
|
||||
newCmd,
|
||||
promptCmd,
|
||||
renameCmd,
|
||||
replyCmd,
|
||||
retryCmd,
|
||||
rmCmd,
|
||||
viewCmd,
|
||||
)
|
||||
|
||||
return root
|
||||
}
|
||||
|
||||
func getSystemPrompt(ctx *lmcli.Context) string {
|
||||
if systemPromptFile != "" {
|
||||
content, err := util.ReadFileContents(systemPromptFile)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not read file contents at %s: %v\n", systemPromptFile, err)
|
||||
}
|
||||
return content
|
||||
}
|
||||
return *ctx.Config.Defaults.SystemPrompt
|
||||
}
|
||||
|
||||
// inputFromArgsOrEditor returns either the provided input from the args slice
|
||||
// (joined with spaces), or if len(args) is 0, opens an editor and returns
|
||||
// whatever input was provided there. placeholder is a string which populates
|
||||
// the editor and gets stripped from the final output.
|
||||
func inputFromArgsOrEditor(args []string, placeholder string, existingMessage string) (message string) {
|
||||
var err error
|
||||
if len(args) == 0 {
|
||||
message, err = util.InputFromEditor(placeholder, "message.*.md", existingMessage)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Failed to get input: %v\n", err)
|
||||
}
|
||||
} else {
|
||||
message = strings.Join(args, " ")
|
||||
}
|
||||
message = strings.Trim(message, " \t\n")
|
||||
return
|
||||
}
|
||||
72
pkg/cmd/continue.go
Normal file
72
pkg/cmd/continue.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ContinueCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "continue <conversation>",
|
||||
Short: "Continue a conversation from the last message",
|
||||
Long: `Re-prompt the conversation with all existing prompts. Useful if a reply was cut short.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not retrieve conversation messages: %v", err)
|
||||
}
|
||||
|
||||
if len(messages) < 2 {
|
||||
return fmt.Errorf("conversation expected to have at least 2 messages")
|
||||
}
|
||||
|
||||
lastMessage := &messages[len(messages)-1]
|
||||
if lastMessage.Role != model.MessageRoleAssistant {
|
||||
return fmt.Errorf("the last message in the conversation is not an assistant message")
|
||||
}
|
||||
|
||||
// Output the contents of the last message so far
|
||||
fmt.Print(lastMessage.Content)
|
||||
|
||||
// Submit the LLM request, allowing it to continue the last message
|
||||
continuedOutput, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error fetching LLM response: %v", err)
|
||||
}
|
||||
|
||||
// Append the new response to the original message
|
||||
lastMessage.Content += strings.TrimRight(continuedOutput, "\n\t ")
|
||||
|
||||
// Update the original message
|
||||
err = ctx.Store.UpdateMessage(lastMessage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not update the last message: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
100
pkg/cmd/edit.go
Normal file
100
pkg/cmd/edit.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func EditCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "edit <conversation>",
|
||||
Short: "Edit the last user reply in a conversation",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
offset, _ := cmd.Flags().GetInt("offset")
|
||||
if offset < 0 {
|
||||
offset = -offset
|
||||
}
|
||||
|
||||
if offset > len(messages)-1 {
|
||||
return fmt.Errorf("Offset %d is before the start of the conversation.", offset)
|
||||
}
|
||||
|
||||
desiredIdx := len(messages) - 1 - offset
|
||||
|
||||
// walk backwards through the conversation deleting messages until and
|
||||
// including the last user message
|
||||
toRemove := []model.Message{}
|
||||
var toEdit *model.Message
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if i == desiredIdx {
|
||||
toEdit = &messages[i]
|
||||
}
|
||||
toRemove = append(toRemove, messages[i])
|
||||
messages = messages[:i]
|
||||
if toEdit != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
newContents := inputFromArgsOrEditor(args[1:], "# Save when finished editing\n", toEdit.Content)
|
||||
switch newContents {
|
||||
case toEdit.Content:
|
||||
return fmt.Errorf("No edits were made.")
|
||||
case "":
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
role, _ := cmd.Flags().GetString("role")
|
||||
if role == "" {
|
||||
role = string(toEdit.Role)
|
||||
} else if role != string(model.MessageRoleUser) && role != string(model.MessageRoleAssistant) {
|
||||
return fmt.Errorf("Invalid role specified. Please use 'user' or 'assistant'.")
|
||||
}
|
||||
|
||||
for _, message := range toRemove {
|
||||
err = ctx.Store.DeleteMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete message: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRole(role),
|
||||
Content: newContents,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Int("offset", 1, "Offset from the last message to edit")
|
||||
cmd.Flags().StringP("role", "r", "", "Role of the edited message (user or assistant)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
122
pkg/cmd/list.go
Normal file
122
pkg/cmd/list.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
const (
|
||||
LS_COUNT int = 5
|
||||
)
|
||||
|
||||
func ListCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "list",
|
||||
Aliases: []string{"ls"},
|
||||
Short: "List conversations",
|
||||
Long: `List conversations in order of recent activity`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
conversations, err := ctx.Store.Conversations()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not fetch conversations: %v", err)
|
||||
}
|
||||
|
||||
type Category struct {
|
||||
name string
|
||||
cutoff time.Duration
|
||||
}
|
||||
|
||||
type ConversationLine struct {
|
||||
timeSinceReply time.Duration
|
||||
formatted string
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
midnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location())
|
||||
dayOfWeek := int(now.Weekday())
|
||||
categories := []Category{
|
||||
{"today", now.Sub(midnight)},
|
||||
{"yesterday", now.Sub(midnight.AddDate(0, 0, -1))},
|
||||
{"this week", now.Sub(midnight.AddDate(0, 0, -dayOfWeek))},
|
||||
{"last week", now.Sub(midnight.AddDate(0, 0, -(dayOfWeek + 7)))},
|
||||
{"this month", now.Sub(monthStart)},
|
||||
{"last month", now.Sub(monthStart.AddDate(0, -1, 0))},
|
||||
{"2 months ago", now.Sub(monthStart.AddDate(0, -2, 0))},
|
||||
{"3 months ago", now.Sub(monthStart.AddDate(0, -3, 0))},
|
||||
{"4 months ago", now.Sub(monthStart.AddDate(0, -4, 0))},
|
||||
{"5 months ago", now.Sub(monthStart.AddDate(0, -5, 0))},
|
||||
{"older", now.Sub(time.Time{})},
|
||||
}
|
||||
categorized := map[string][]ConversationLine{}
|
||||
|
||||
all, _ := cmd.Flags().GetBool("all")
|
||||
|
||||
for _, conversation := range conversations {
|
||||
lastMessage, err := ctx.Store.LastMessage(&conversation)
|
||||
if lastMessage == nil || err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
messageAge := now.Sub(lastMessage.CreatedAt)
|
||||
|
||||
var category string
|
||||
for _, c := range categories {
|
||||
if messageAge < c.cutoff {
|
||||
category = c.name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
formatted := fmt.Sprintf(
|
||||
"%s - %s - %s",
|
||||
conversation.ShortName.String,
|
||||
util.HumanTimeElapsedSince(messageAge),
|
||||
conversation.Title,
|
||||
)
|
||||
|
||||
categorized[category] = append(
|
||||
categorized[category],
|
||||
ConversationLine{messageAge, formatted},
|
||||
)
|
||||
}
|
||||
|
||||
count, _ := cmd.Flags().GetInt("count")
|
||||
var conversationsPrinted int
|
||||
outer:
|
||||
for _, category := range categories {
|
||||
conversationLines, ok := categorized[category.name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
slices.SortFunc(conversationLines, func(a, b ConversationLine) int {
|
||||
return int(a.timeSinceReply - b.timeSinceReply)
|
||||
})
|
||||
|
||||
fmt.Printf("%s:\n", category.name)
|
||||
for _, conv := range conversationLines {
|
||||
if conversationsPrinted >= count && !all {
|
||||
fmt.Printf("%d remaining message(s), use --all to view.\n", len(conversations)-conversationsPrinted)
|
||||
break outer
|
||||
}
|
||||
|
||||
fmt.Printf(" %s\n", conv.formatted)
|
||||
conversationsPrinted++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("all", false, "Show all conversations")
|
||||
cmd.Flags().Int("count", LS_COUNT, "How many conversations to show")
|
||||
|
||||
return cmd
|
||||
}
|
||||
60
pkg/cmd/new.go
Normal file
60
pkg/cmd/new.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func NewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "new [message]",
|
||||
Short: "Start a new conversation",
|
||||
Long: `Start a new conversation with the Large Language Model.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
messageContents := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if messageContents == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
conversation := &model.Conversation{}
|
||||
err := ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not save new conversation: %v", err)
|
||||
}
|
||||
|
||||
messages := []model.Message{
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: getSystemPrompt(ctx),
|
||||
},
|
||||
{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleUser,
|
||||
Content: messageContents,
|
||||
},
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, messages...)
|
||||
|
||||
title, err := cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not generate title for conversation: %v\n", err)
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation after generating title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
42
pkg/cmd/prompt.go
Normal file
42
pkg/cmd/prompt.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func PromptCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "prompt [message]",
|
||||
Short: "Do a one-shot prompt",
|
||||
Long: `Prompt the Large Language Model and get a response.`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
message := inputFromArgsOrEditor(args, "# What would you like to say?\n", "")
|
||||
if message == "" {
|
||||
return fmt.Errorf("No message was provided.")
|
||||
}
|
||||
|
||||
messages := []model.Message{
|
||||
{
|
||||
Role: model.MessageRoleSystem,
|
||||
Content: getSystemPrompt(ctx),
|
||||
},
|
||||
{
|
||||
Role: model.MessageRoleUser,
|
||||
Content: message,
|
||||
},
|
||||
}
|
||||
|
||||
_, err := cmdutil.FetchAndShowCompletion(ctx, messages, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error fetching LLM response: %v", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
60
pkg/cmd/remove.go
Normal file
60
pkg/cmd/remove.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RemoveCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rm <conversation>...",
|
||||
Short: "Remove conversations",
|
||||
Long: `Remove conversations by their short names.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
var toRemove []*model.Conversation
|
||||
for _, shortName := range args {
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
toRemove = append(toRemove, conversation)
|
||||
}
|
||||
var errors []error
|
||||
for _, c := range toRemove {
|
||||
err := ctx.Store.DeleteConversation(c)
|
||||
if err != nil {
|
||||
errors = append(errors, fmt.Errorf("Could not remove conversation %s: %v", c.ShortName.String, err))
|
||||
}
|
||||
}
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("Could not remove some conversations: %v", errors)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
var completions []string
|
||||
outer:
|
||||
for _, completion := range ctx.Store.ConversationShortNameCompletions(toComplete) {
|
||||
parts := strings.Split(completion, "\t")
|
||||
for _, arg := range args {
|
||||
if parts[0] == arg {
|
||||
continue outer
|
||||
}
|
||||
}
|
||||
completions = append(completions, completion)
|
||||
}
|
||||
return completions, compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
62
pkg/cmd/rename.go
Normal file
62
pkg/cmd/rename.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
)
|
||||
|
||||
func RenameCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "rename <conversation> [title]",
|
||||
Short: "Rename a conversation",
|
||||
Long: `Renames a conversation, either with the provided title or by generating a new name.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
var err error
|
||||
|
||||
generate, _ := cmd.Flags().GetBool("generate")
|
||||
var title string
|
||||
if generate {
|
||||
title, err = cmdutil.GenerateTitle(ctx, conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not generate conversation title: %v", err)
|
||||
}
|
||||
} else {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("Conversation title not provided.")
|
||||
}
|
||||
title = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
conversation.Title = title
|
||||
err = ctx.Store.SaveConversation(conversation)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save conversation with new title: %v\n", err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("generate", false, "Generate a conversation title")
|
||||
|
||||
return cmd
|
||||
}
|
||||
49
pkg/cmd/reply.go
Normal file
49
pkg/cmd/reply.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ReplyCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "reply <conversation> [message]",
|
||||
Short: "Reply to a conversation",
|
||||
Long: `Sends a reply to conversation and writes the response to stdout.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
reply := inputFromArgsOrEditor(args[1:], "# How would you like to reply?\n", "")
|
||||
if reply == "" {
|
||||
return fmt.Errorf("No reply was provided.")
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true, model.Message{
|
||||
ConversationID: conversation.ID,
|
||||
Role: model.MessageRoleUser,
|
||||
Content: reply,
|
||||
})
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
58
pkg/cmd/retry.go
Normal file
58
pkg/cmd/retry.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func RetryCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "retry <conversation>",
|
||||
Short: "Retry the last user reply in a conversation",
|
||||
Long: `Re-prompt the conversation up to the last user response. Can be used to regenerate the last assistant reply, or simply generate one if an error occurred.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
// walk backwards through the conversation and delete messages, break
|
||||
// when we find the latest user response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == model.MessageRoleUser {
|
||||
break
|
||||
}
|
||||
|
||||
err = ctx.Store.DeleteMessage(&messages[i])
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not delete previous reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmdutil.HandleConversationReply(ctx, conversation, true)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
292
pkg/cmd/util/util.go
Normal file
292
pkg/cmd/util/util.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// fetchAndShowCompletion prompts the LLM with the given messages and streams
|
||||
// the response to stdout. Returns all model reply messages.
|
||||
func FetchAndShowCompletion(ctx *lmcli.Context, messages []model.Message, callback func(model.Message)) (string, error) {
|
||||
content := make(chan string) // receives the reponse from LLM
|
||||
defer close(content)
|
||||
|
||||
// render all content received over the channel
|
||||
go ShowDelayedContent(content)
|
||||
|
||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestParams := model.RequestParameters{
|
||||
Model: *ctx.Config.Defaults.Model,
|
||||
MaxTokens: *ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *ctx.Config.Defaults.Temperature,
|
||||
ToolBag: ctx.EnabledTools,
|
||||
}
|
||||
|
||||
response, err := completionProvider.CreateChatCompletionStream(
|
||||
context.Background(), requestParams, messages, callback, content,
|
||||
)
|
||||
if response != "" {
|
||||
// there was some content, so break to a new line after it
|
||||
fmt.Println()
|
||||
|
||||
if err != nil {
|
||||
lmcli.Warn("Received partial response. Error: %v\n", err)
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// lookupConversation either returns the conversation found by the
|
||||
// short name or exits the program
|
||||
func LookupConversation(ctx *lmcli.Context, shortName string) *model.Conversation {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not lookup conversation: %v\n", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
lmcli.Fatal("Conversation not found with short name: %s\n", shortName)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func LookupConversationE(ctx *lmcli.Context, shortName string) (*model.Conversation, error) {
|
||||
c, err := ctx.Store.ConversationByShortName(shortName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not lookup conversation: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return nil, fmt.Errorf("Conversation not found with short name: %s", shortName)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// handleConversationReply handles sending messages to an existing
|
||||
// conversation, optionally persisting both the sent replies and responses.
|
||||
func HandleConversationReply(ctx *lmcli.Context, c *model.Conversation, persist bool, toSend ...model.Message) {
|
||||
existing, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Could not retrieve messages for conversation: %s\n", c.Title)
|
||||
}
|
||||
|
||||
if persist {
|
||||
for _, message := range toSend {
|
||||
err = ctx.Store.SaveMessage(&message)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save %s message: %v\n", message.Role, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allMessages := append(existing, toSend...)
|
||||
|
||||
RenderConversation(ctx, allMessages, true)
|
||||
|
||||
// render a message header with no contents
|
||||
RenderMessage(ctx, (&model.Message{Role: model.MessageRoleAssistant}))
|
||||
|
||||
replyCallback := func(reply model.Message) {
|
||||
if !persist {
|
||||
return
|
||||
}
|
||||
|
||||
reply.ConversationID = c.ID
|
||||
err = ctx.Store.SaveMessage(&reply)
|
||||
if err != nil {
|
||||
lmcli.Warn("Could not save reply: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = FetchAndShowCompletion(ctx, allMessages, replyCallback)
|
||||
if err != nil {
|
||||
lmcli.Fatal("Error fetching LLM response: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func FormatForExternalPrompt(messages []model.Message, system bool) string {
|
||||
sb := strings.Builder{}
|
||||
for _, message := range messages {
|
||||
if message.Content == "" {
|
||||
continue
|
||||
}
|
||||
switch message.Role {
|
||||
case model.MessageRoleAssistant, model.MessageRoleToolCall:
|
||||
sb.WriteString("Assistant:\n\n")
|
||||
case model.MessageRoleUser:
|
||||
sb.WriteString("User:\n\n")
|
||||
default:
|
||||
continue
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s", lipgloss.NewStyle().PaddingLeft(1).Render(message.Content)))
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func GenerateTitle(ctx *lmcli.Context, c *model.Conversation) (string, error) {
|
||||
messages, err := ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
const prompt = `Above is an excerpt from a conversation between a user and AI assistant. Please reply with a short title (no more than 8 words) that reflects the topic of the conversation, read from the user's perspective.
|
||||
|
||||
Example conversation:
|
||||
|
||||
"""
|
||||
User:
|
||||
|
||||
Hello!
|
||||
|
||||
Assistant:
|
||||
|
||||
Hello! How may I assist you?
|
||||
"""
|
||||
|
||||
Example response:
|
||||
|
||||
"""
|
||||
Title: A brief introduction
|
||||
"""
|
||||
`
|
||||
conversation := FormatForExternalPrompt(messages, false)
|
||||
|
||||
generateRequest := []model.Message{
|
||||
{
|
||||
Role: model.MessageRoleUser,
|
||||
Content: fmt.Sprintf("\"\"\"\n%s\n\"\"\"\n\n%s", conversation, prompt),
|
||||
},
|
||||
}
|
||||
|
||||
completionProvider, err := ctx.GetCompletionProvider(*ctx.Config.Conversations.TitleGenerationModel)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
requestParams := model.RequestParameters{
|
||||
Model: *ctx.Config.Conversations.TitleGenerationModel,
|
||||
MaxTokens: 25,
|
||||
}
|
||||
|
||||
response, err := completionProvider.CreateChatCompletion(context.Background(), requestParams, generateRequest, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
response = strings.TrimPrefix(response, "Title: ")
|
||||
response = strings.Trim(response, "\"")
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ShowWaitAnimation prints an animated ellipses to stdout until something is
|
||||
// received on the signal channel. An empty string sent to the channel to
|
||||
// notify the caller that the animation has completed (carriage returned).
|
||||
func ShowWaitAnimation(signal chan any) {
|
||||
// Save the current cursor position
|
||||
fmt.Print("\033[s")
|
||||
|
||||
animationStep := 0
|
||||
for {
|
||||
select {
|
||||
case _ = <-signal:
|
||||
// Relmcli the cursor position
|
||||
fmt.Print("\033[u")
|
||||
signal <- ""
|
||||
return
|
||||
default:
|
||||
// Move the cursor to the saved position
|
||||
modSix := animationStep % 6
|
||||
if modSix == 3 || modSix == 0 {
|
||||
fmt.Print("\033[u")
|
||||
}
|
||||
if modSix < 3 {
|
||||
fmt.Print(".")
|
||||
} else {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
animationStep++
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ShowDelayedContent displays a waiting animation to stdout while waiting
|
||||
// for content to be received on the provided channel. As soon as any (possibly
|
||||
// chunked) content is received on the channel, the waiting animation is
|
||||
// replaced by the content.
|
||||
// Blocks until the channel is closed.
|
||||
func ShowDelayedContent(content <-chan string) {
|
||||
waitSignal := make(chan any)
|
||||
go ShowWaitAnimation(waitSignal)
|
||||
|
||||
firstChunk := true
|
||||
for chunk := range content {
|
||||
if firstChunk {
|
||||
// notify wait animation that we've received data
|
||||
waitSignal <- ""
|
||||
// wait for signal that wait animation has completed
|
||||
<-waitSignal
|
||||
firstChunk = false
|
||||
}
|
||||
fmt.Print(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
// RenderConversation renders the given messages to TTY, with optional space
|
||||
// for a subsequent message. spaceForResponse controls how many '\n' characters
|
||||
// are printed immediately after the final message (1 if false, 2 if true)
|
||||
func RenderConversation(ctx *lmcli.Context, messages []model.Message, spaceForResponse bool) {
|
||||
l := len(messages)
|
||||
for i, message := range messages {
|
||||
RenderMessage(ctx, &message)
|
||||
if i < l-1 || spaceForResponse {
|
||||
// print an additional space before the next message
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RenderMessage(ctx *lmcli.Context, m *model.Message) {
|
||||
var messageAge string
|
||||
if m.CreatedAt.IsZero() {
|
||||
messageAge = "now"
|
||||
} else {
|
||||
now := time.Now()
|
||||
messageAge = util.HumanTimeElapsedSince(now.Sub(m.CreatedAt))
|
||||
}
|
||||
|
||||
headerStyle := lipgloss.NewStyle().Bold(true)
|
||||
|
||||
switch m.Role {
|
||||
case model.MessageRoleSystem:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("9")) // bright red
|
||||
case model.MessageRoleUser:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("10")) // bright green
|
||||
case model.MessageRoleAssistant:
|
||||
headerStyle = headerStyle.Foreground(lipgloss.Color("12")) // bright blue
|
||||
}
|
||||
|
||||
role := headerStyle.Render(m.Role.FriendlyRole())
|
||||
|
||||
separatorStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("3"))
|
||||
separator := separatorStyle.Render("===")
|
||||
timestamp := separatorStyle.Render(messageAge)
|
||||
|
||||
fmt.Printf("%s %s - %s %s\n\n", separator, role, timestamp, separator)
|
||||
if m.Content != "" {
|
||||
ctx.Chroma.Highlight(os.Stdout, m.Content)
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
45
pkg/cmd/view.go
Normal file
45
pkg/cmd/view.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func ViewCmd(ctx *lmcli.Context) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "view <conversation>",
|
||||
Short: "View messages in a conversation",
|
||||
Long: `Finds a conversation by its short name and displays its contents.`,
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
argCount := 1
|
||||
if err := cobra.MinimumNArgs(argCount)(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
shortName := args[0]
|
||||
conversation := cmdutil.LookupConversation(ctx, shortName)
|
||||
|
||||
messages, err := ctx.Store.Messages(conversation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not retrieve messages for conversation: %s", conversation.Title)
|
||||
}
|
||||
|
||||
cmdutil.RenderConversation(ctx, messages, false)
|
||||
return nil
|
||||
},
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
compMode := cobra.ShellCompDirectiveNoFileComp
|
||||
if len(args) != 0 {
|
||||
return nil, compMode
|
||||
}
|
||||
return ctx.Store.ConversationShortNameCompletions(toComplete), compMode
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
71
pkg/lmcli/config.go
Normal file
71
pkg/lmcli/config.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package lmcli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util"
|
||||
"github.com/go-yaml/yaml"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Defaults *struct {
|
||||
SystemPrompt *string `yaml:"systemPrompt" default:"You are a helpful assistant."`
|
||||
MaxTokens *int `yaml:"maxTokens" default:"256"`
|
||||
Temperature *float32 `yaml:"temperature" default:"0.7"`
|
||||
Model *string `yaml:"model" default:"gpt-4"`
|
||||
} `yaml:"defaults"`
|
||||
Conversations *struct {
|
||||
TitleGenerationModel *string `yaml:"titleGenerationModel" default:"gpt-3.5-turbo"`
|
||||
} `yaml:"conversations"`
|
||||
Tools *struct {
|
||||
EnabledTools *[]string `yaml:"enabledTools"`
|
||||
} `yaml:"tools"`
|
||||
OpenAI *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
Models *[]string `yaml:"models"`
|
||||
} `yaml:"openai"`
|
||||
Anthropic *struct {
|
||||
APIKey *string `yaml:"apiKey" default:"your_key_here"`
|
||||
Models *[]string `yaml:"models"`
|
||||
} `yaml:"anthropic"`
|
||||
Chroma *struct {
|
||||
Style *string `yaml:"style" default:"onedark"`
|
||||
Formatter *string `yaml:"formatter" default:"terminal16m"`
|
||||
} `yaml:"chroma"`
|
||||
}
|
||||
|
||||
func NewConfig(configFile string) (*Config, error) {
|
||||
shouldWriteDefaults := false
|
||||
c := &Config{}
|
||||
|
||||
configExists := true
|
||||
configBytes, err := os.ReadFile(configFile)
|
||||
if os.IsNotExist(err) {
|
||||
configExists = false
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("Could not read config file: %v", err)
|
||||
} else {
|
||||
yaml.Unmarshal(configBytes, c)
|
||||
}
|
||||
|
||||
shouldWriteDefaults = util.SetStructDefaults(c)
|
||||
if !configExists || shouldWriteDefaults {
|
||||
if configExists {
|
||||
fmt.Printf("Saving new defaults to configuration, backing up existing configuration to %s\n", configFile+".bak")
|
||||
os.Rename(configFile, configFile+".bak")
|
||||
}
|
||||
fmt.Printf("Writing configuration file to %s\n", configFile)
|
||||
file, err := os.Create(configFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not open config file for writing: %v", err)
|
||||
}
|
||||
bytes, _ := yaml.Marshal(c)
|
||||
_, err = file.Write(bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not save default configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
123
pkg/lmcli/lmcli.go
Normal file
123
pkg/lmcli/lmcli.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package lmcli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/anthropic"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider/openai"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/util/tty"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Context struct {
|
||||
Config *Config
|
||||
Store ConversationStore
|
||||
|
||||
Chroma *tty.ChromaHighlighter
|
||||
EnabledTools []model.Tool
|
||||
}
|
||||
|
||||
func NewContext() (*Context, error) {
|
||||
configFile := filepath.Join(configDir(), "config.yaml")
|
||||
config, err := NewConfig(configFile)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
databaseFile := filepath.Join(dataDir(), "conversations.db")
|
||||
db, err := gorm.Open(sqlite.Open(databaseFile), &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error establishing connection to store: %v", err)
|
||||
}
|
||||
store, err := NewSQLStore(db)
|
||||
if err != nil {
|
||||
Fatal("%v\n", err)
|
||||
}
|
||||
|
||||
chroma := tty.NewChromaHighlighter("markdown", *config.Chroma.Formatter, *config.Chroma.Style)
|
||||
|
||||
var enabledTools []model.Tool
|
||||
for _, toolName := range *config.Tools.EnabledTools {
|
||||
tool, ok := tools.AvailableTools[toolName]
|
||||
if ok {
|
||||
enabledTools = append(enabledTools, tool)
|
||||
}
|
||||
}
|
||||
|
||||
return &Context{config, store, chroma, enabledTools}, nil
|
||||
}
|
||||
|
||||
func (c *Context) GetModels() (models []string) {
|
||||
for _, m := range *c.Config.Anthropic.Models {
|
||||
models = append(models, m)
|
||||
}
|
||||
for _, m := range *c.Config.OpenAI.Models {
|
||||
models = append(models, m)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Context) GetCompletionProvider(model string) (provider.ChatCompletionClient, error) {
|
||||
for _, m := range *c.Config.Anthropic.Models {
|
||||
if m == model {
|
||||
anthropic := &anthropic.AnthropicClient{
|
||||
APIKey: *c.Config.Anthropic.APIKey,
|
||||
}
|
||||
return anthropic, nil
|
||||
}
|
||||
}
|
||||
for _, m := range *c.Config.OpenAI.Models {
|
||||
if m == model {
|
||||
openai := &openai.OpenAIClient{
|
||||
APIKey: *c.Config.OpenAI.APIKey,
|
||||
}
|
||||
return openai, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("unknown model: %s", model)
|
||||
}
|
||||
|
||||
func configDir() string {
|
||||
var configDir string
|
||||
|
||||
xdgConfigHome := os.Getenv("XDG_CONFIG_HOME")
|
||||
if xdgConfigHome != "" {
|
||||
configDir = filepath.Join(xdgConfigHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
configDir = filepath.Join(userHomeDir, ".config/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(configDir, 0755)
|
||||
return configDir
|
||||
}
|
||||
|
||||
func dataDir() string {
|
||||
var dataDir string
|
||||
|
||||
xdgDataHome := os.Getenv("XDG_DATA_HOME")
|
||||
if xdgDataHome != "" {
|
||||
dataDir = filepath.Join(xdgDataHome, "lmcli")
|
||||
} else {
|
||||
userHomeDir, _ := os.UserHomeDir()
|
||||
dataDir = filepath.Join(userHomeDir, ".local/share/lmcli")
|
||||
}
|
||||
|
||||
os.MkdirAll(dataDir, 0755)
|
||||
return dataDir
|
||||
}
|
||||
|
||||
func Fatal(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
func Warn(format string, args ...any) {
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
}
|
||||
58
pkg/lmcli/model/conversation.go
Normal file
58
pkg/lmcli/model/conversation.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MessageRole string
|
||||
|
||||
const (
|
||||
MessageRoleSystem MessageRole = "system"
|
||||
MessageRoleUser MessageRole = "user"
|
||||
MessageRoleAssistant MessageRole = "assistant"
|
||||
MessageRoleToolCall MessageRole = "tool_call"
|
||||
MessageRoleToolResult MessageRole = "tool_result"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ConversationID uint `gorm:"foreignKey:ConversationID"`
|
||||
Content string
|
||||
Role MessageRole
|
||||
CreatedAt time.Time
|
||||
ToolCalls ToolCalls // a json array of tool calls (from the modl)
|
||||
ToolResults ToolResults // a json array of tool results
|
||||
}
|
||||
|
||||
type Conversation struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
ShortName sql.NullString
|
||||
Title string
|
||||
}
|
||||
|
||||
type RequestParameters struct {
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
|
||||
SystemPrompt string
|
||||
ToolBag []Tool
|
||||
}
|
||||
|
||||
// FriendlyRole returns a human friendly signifier for the message's role.
|
||||
func (m *MessageRole) FriendlyRole() string {
|
||||
var friendlyRole string
|
||||
switch *m {
|
||||
case MessageRoleUser:
|
||||
friendlyRole = "You"
|
||||
case MessageRoleSystem:
|
||||
friendlyRole = "System"
|
||||
case MessageRoleAssistant:
|
||||
friendlyRole = "Assistant"
|
||||
default:
|
||||
friendlyRole = string(*m)
|
||||
}
|
||||
return friendlyRole
|
||||
}
|
||||
98
pkg/lmcli/model/tool.go
Normal file
98
pkg/lmcli/model/tool.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Tool struct {
|
||||
Name string
|
||||
Description string
|
||||
Parameters []ToolParameter
|
||||
Impl func(*Tool, map[string]interface{}) (string, error)
|
||||
}
|
||||
|
||||
type ToolParameter struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"` // "string", "integer", "boolean"
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
type ToolCalls []ToolCall
|
||||
|
||||
func (tc *ToolCalls) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tc = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tc)
|
||||
return
|
||||
}
|
||||
|
||||
func (tc ToolCalls) Value() (driver.Value, error) {
|
||||
if len(tc) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal(tc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolCalls to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolCallID string `json:"toolCallID"`
|
||||
ToolName string `json:"toolName,omitempty"`
|
||||
Result string `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type ToolResults []ToolResult
|
||||
|
||||
func (tr *ToolResults) Scan(value any) (err error) {
|
||||
s := value.(string)
|
||||
if value == nil || s == "" {
|
||||
*tr = nil
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal([]byte(s), tr)
|
||||
return
|
||||
}
|
||||
|
||||
func (tr ToolResults) Value() (driver.Value, error) {
|
||||
if len(tr) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
jsonBytes, err := json.Marshal([]ToolResult(tr))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal ToolResults to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
type CallResult struct {
|
||||
Message string `json:"message"`
|
||||
Result any `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
func (r CallResult) ToJson() (string, error) {
|
||||
if r.Message == "" {
|
||||
// When message not supplied, assume success
|
||||
r.Message = "success"
|
||||
}
|
||||
|
||||
jsonBytes, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not marshal CallResult to JSON: %v\n", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
348
pkg/lmcli/provider/anthropic/anthropic.go
Normal file
348
pkg/lmcli/provider/anthropic/anthropic.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
)
|
||||
|
||||
type AnthropicClient struct {
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
//TopP float32 `json:"top_p,omitempty"`
|
||||
//TopK float32 `json:"top_k,omitempty"`
|
||||
}
|
||||
|
||||
type OriginalContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Id string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Role string `json:"role"`
|
||||
Content []OriginalContent `json:"content"`
|
||||
}
|
||||
|
||||
const FUNCTION_STOP_SEQUENCE = "</function_calls>"
|
||||
|
||||
func buildRequest(params model.RequestParameters, messages []model.Message) Request {
|
||||
requestBody := Request{
|
||||
Model: params.Model,
|
||||
Messages: make([]Message, len(messages)),
|
||||
System: params.SystemPrompt,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Stream: false,
|
||||
|
||||
StopSequences: []string{
|
||||
FUNCTION_STOP_SEQUENCE,
|
||||
"\n\nHuman:",
|
||||
},
|
||||
}
|
||||
|
||||
startIdx := 0
|
||||
if len(messages) > 0 && messages[0].Role == model.MessageRoleSystem {
|
||||
requestBody.System = messages[0].Content
|
||||
requestBody.Messages = requestBody.Messages[1:]
|
||||
startIdx = 1
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
if len(requestBody.System) > 0 {
|
||||
// add a divider between existing system prompt and tools
|
||||
requestBody.System += "\n\n---\n\n"
|
||||
}
|
||||
requestBody.System += buildToolsSystemPrompt(params.ToolBag)
|
||||
}
|
||||
|
||||
for i, msg := range messages[startIdx:] {
|
||||
message := &requestBody.Messages[i]
|
||||
|
||||
switch msg.Role {
|
||||
case model.MessageRoleToolCall:
|
||||
message.Role = "assistant"
|
||||
if msg.Content != "" {
|
||||
message.Content = msg.Content
|
||||
}
|
||||
xmlFuncCalls := convertToolCallsToXMLFunctionCalls(msg.ToolCalls)
|
||||
xmlString, err := xmlFuncCalls.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolCall to XMLFunctionCall")
|
||||
}
|
||||
if len(message.Content) > 0 {
|
||||
message.Content += fmt.Sprintf("\n\n%s", xmlString)
|
||||
} else {
|
||||
message.Content = xmlString
|
||||
}
|
||||
case model.MessageRoleToolResult:
|
||||
xmlFuncResults := convertToolResultsToXMLFunctionResult(msg.ToolResults)
|
||||
xmlString, err := xmlFuncResults.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []ToolResult to XMLFunctionResults")
|
||||
}
|
||||
message.Role = "user"
|
||||
message.Content = xmlString
|
||||
default:
|
||||
message.Role = string(msg.Role)
|
||||
message.Content = msg.Content
|
||||
}
|
||||
}
|
||||
return requestBody
|
||||
}
|
||||
|
||||
func sendRequest(ctx context.Context, c *AnthropicClient, r Request) (*http.Response, error) {
|
||||
url := "https://api.anthropic.com/v1/messages"
|
||||
|
||||
jsonBody, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request body: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("x-api-key", c.APIKey)
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send HTTP request: %v", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var response Response
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
sb := strings.Builder{}
|
||||
for _, content := range response.Content {
|
||||
var reply model.Message
|
||||
switch content.Type {
|
||||
case "text":
|
||||
reply = model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.Text,
|
||||
}
|
||||
sb.WriteString(reply.Content)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported message type: %s", content.Type)
|
||||
}
|
||||
if callback != nil {
|
||||
callback(reply)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
func (c *AnthropicClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
request := buildRequest(params, messages)
|
||||
request.Stream = true
|
||||
|
||||
resp, err := sendRequest(ctx, c, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
sb := strings.Builder{}
|
||||
|
||||
isToolCall := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if line[0] == '{' {
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(line), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data '%s': %v", line, err)
|
||||
}
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event: %s", line)
|
||||
}
|
||||
switch eventType {
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
return sb.String(), fmt.Errorf("unknown event type: %s", eventType)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "data:") {
|
||||
data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
||||
var event map[string]interface{}
|
||||
err := json.Unmarshal([]byte(data), &event)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal event data: %v", err)
|
||||
}
|
||||
|
||||
eventType, ok := event["type"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid event type")
|
||||
}
|
||||
|
||||
switch eventType {
|
||||
case "message_start":
|
||||
// noop
|
||||
case "ping":
|
||||
// write an empty string to signal start of text
|
||||
output <- ""
|
||||
case "content_block_start":
|
||||
// ignore?
|
||||
case "content_block_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid content block delta")
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid text delta")
|
||||
}
|
||||
sb.WriteString(text)
|
||||
output <- text
|
||||
case "content_block_stop":
|
||||
// ignore?
|
||||
case "message_delta":
|
||||
delta, ok := event["delta"].(map[string]interface{})
|
||||
if !ok {
|
||||
return "", fmt.Errorf("invalid message delta")
|
||||
}
|
||||
stopReason, ok := delta["stop_reason"].(string)
|
||||
if ok && stopReason == "stop_sequence" {
|
||||
stopSequence, ok := delta["stop_sequence"].(string)
|
||||
if ok && stopSequence == FUNCTION_STOP_SEQUENCE {
|
||||
content := sb.String()
|
||||
|
||||
start := strings.Index(content, "<function_calls>")
|
||||
if start == -1 {
|
||||
return content, fmt.Errorf("reached </function_calls> stop sequence but no opening tag found")
|
||||
}
|
||||
|
||||
isToolCall = true
|
||||
|
||||
funcCallXml := content[start:]
|
||||
funcCallXml += FUNCTION_STOP_SEQUENCE
|
||||
|
||||
sb.WriteString(FUNCTION_STOP_SEQUENCE)
|
||||
output <- FUNCTION_STOP_SEQUENCE
|
||||
|
||||
// Extract function calls
|
||||
var functionCalls XMLFunctionCalls
|
||||
err := xml.Unmarshal([]byte(sb.String()), &functionCalls)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to unmarshal function_calls: %v", err)
|
||||
}
|
||||
|
||||
// Execute function calls
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
// xml stripped from content
|
||||
Content: content[:start],
|
||||
ToolCalls: convertXMLFunctionCallsToToolCalls(functionCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
toolReply := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(toolCall)
|
||||
callback(toolReply)
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
// added to the original messages
|
||||
messages = append(append(messages, toolCall), toolReply)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
}
|
||||
}
|
||||
case "message_stop":
|
||||
// return the completed message
|
||||
if callback != nil {
|
||||
if !isToolCall {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: sb.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return sb.String(), nil
|
||||
case "error":
|
||||
return sb.String(), fmt.Errorf("an error occurred: %s", event["error"])
|
||||
default:
|
||||
fmt.Printf("\nUnrecognized event: %s\n", data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unexpected end of stream")
|
||||
}
|
||||
230
pkg/lmcli/provider/anthropic/tools.go
Normal file
230
pkg/lmcli/provider/anthropic/tools.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
const TOOL_PREAMBLE = `You have access to the following tools when replying.
|
||||
|
||||
You may call them like this:
|
||||
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:`
|
||||
|
||||
const TOOL_PREAMBLE_FOOTER = `Recognize the utility of these tools in a broad range of different applications, and the power they give you to solve a wide range of different problems. However, ensure that the tools are used judiciously and only when clearly relevant to the user's request. Specifically:
|
||||
|
||||
1. Only use a tool if the user has explicitly requested or provided information that warrants its use. Do not make assumptions about files or data existing without the user mentioning them.
|
||||
|
||||
2. If there is ambiguity about whether using a tool is appropriate, ask a clarifying question to the user before proceeding. Confirm your understanding of their request and intent.
|
||||
|
||||
3. Prioritize providing direct responses and explanations based on your own knowledge and understanding. Use tools to supplement and enhance your responses when clearly applicable, but not as a default action.`
|
||||
|
||||
type XMLTools struct {
|
||||
XMLName struct{} `xml:"tools"`
|
||||
ToolDescriptions []XMLToolDescription `xml:"tool_description"`
|
||||
}
|
||||
|
||||
type XMLToolDescription struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Description string `xml:"description"`
|
||||
Parameters []XMLToolParameter `xml:"parameters>parameter"`
|
||||
}
|
||||
|
||||
type XMLToolParameter struct {
|
||||
Name string `xml:"name"`
|
||||
Type string `xml:"type"`
|
||||
Description string `xml:"description"`
|
||||
}
|
||||
|
||||
type XMLFunctionCalls struct {
|
||||
XMLName struct{} `xml:"function_calls"`
|
||||
Invoke []XMLFunctionInvoke `xml:"invoke"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvoke struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Parameters XMLFunctionInvokeParameters `xml:"parameters"`
|
||||
}
|
||||
|
||||
type XMLFunctionInvokeParameters struct {
|
||||
String string `xml:",innerxml"`
|
||||
}
|
||||
|
||||
type XMLFunctionResults struct {
|
||||
XMLName struct{} `xml:"function_results"`
|
||||
Result []XMLFunctionResult `xml:"result"`
|
||||
}
|
||||
|
||||
type XMLFunctionResult struct {
|
||||
ToolName string `xml:"tool_name"`
|
||||
Stdout string `xml:"stdout"`
|
||||
}
|
||||
|
||||
// accepts raw XML from XMLFunctionInvokeParameters.String, returns map of
|
||||
// parameters name to value
|
||||
func parseFunctionParametersXML(params string) map[string]interface{} {
|
||||
lines := strings.Split(params, "\n")
|
||||
ret := make(map[string]interface{}, len(lines))
|
||||
for _, line := range lines {
|
||||
i := strings.Index(line, ">")
|
||||
if i == -1 {
|
||||
continue
|
||||
}
|
||||
j := strings.Index(line, "</")
|
||||
if j == -1 {
|
||||
continue
|
||||
}
|
||||
// chop from after opening < to first > to get parameter name,
|
||||
// then chop after > to first </ to get parameter value
|
||||
ret[line[1:i]] = line[i+1 : j]
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func convertToolsToXMLTools(tools []model.Tool) XMLTools {
|
||||
converted := make([]XMLToolDescription, len(tools))
|
||||
for i, tool := range tools {
|
||||
converted[i].ToolName = tool.Name
|
||||
converted[i].Description = tool.Description
|
||||
|
||||
params := make([]XMLToolParameter, len(tool.Parameters))
|
||||
for j, param := range tool.Parameters {
|
||||
params[j].Name = param.Name
|
||||
params[j].Description = param.Description
|
||||
params[j].Type = param.Type
|
||||
}
|
||||
|
||||
converted[i].Parameters = params
|
||||
}
|
||||
return XMLTools{
|
||||
ToolDescriptions: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertXMLFunctionCallsToToolCalls(functionCalls XMLFunctionCalls) []model.ToolCall {
|
||||
toolCalls := make([]model.ToolCall, len(functionCalls.Invoke))
|
||||
for i, invoke := range functionCalls.Invoke {
|
||||
toolCalls[i].Name = invoke.ToolName
|
||||
toolCalls[i].Parameters = parseFunctionParametersXML(invoke.Parameters.String)
|
||||
}
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func convertToolCallsToXMLFunctionCalls(toolCalls []model.ToolCall) XMLFunctionCalls {
|
||||
converted := make([]XMLFunctionInvoke, len(toolCalls))
|
||||
for i, toolCall := range toolCalls {
|
||||
var params XMLFunctionInvokeParameters
|
||||
var paramXML string
|
||||
for key, value := range toolCall.Parameters {
|
||||
paramXML += fmt.Sprintf("<%s>%v</%s>\n", key, value, key)
|
||||
}
|
||||
params.String = paramXML
|
||||
converted[i] = XMLFunctionInvoke{
|
||||
ToolName: toolCall.Name,
|
||||
Parameters: params,
|
||||
}
|
||||
}
|
||||
return XMLFunctionCalls{
|
||||
Invoke: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func convertToolResultsToXMLFunctionResult(toolResults []model.ToolResult) XMLFunctionResults {
|
||||
converted := make([]XMLFunctionResult, len(toolResults))
|
||||
for i, result := range toolResults {
|
||||
converted[i].ToolName = result.ToolName
|
||||
converted[i].Stdout = result.Result
|
||||
}
|
||||
return XMLFunctionResults{
|
||||
Result: converted,
|
||||
}
|
||||
}
|
||||
|
||||
func buildToolsSystemPrompt(tools []model.Tool) string {
|
||||
xmlTools := convertToolsToXMLTools(tools)
|
||||
xmlToolsString, err := xmlTools.XMLString()
|
||||
if err != nil {
|
||||
panic("Could not serialize []model.Tool to XMLTools")
|
||||
}
|
||||
return TOOL_PREAMBLE + "\n\n" + xmlToolsString + "\n\n" + TOOL_PREAMBLE_FOOTER
|
||||
}
|
||||
|
||||
func (x XMLTools) XMLString() (string, error) {
|
||||
tmpl, err := template.New("tools").Parse(`<tools>
|
||||
{{range .ToolDescriptions}}<tool_description>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<description>
|
||||
{{.Description}}
|
||||
</description>
|
||||
<parameters>
|
||||
{{range .Parameters}}<parameter>
|
||||
<name>{{.Name}}</name>
|
||||
<type>{{.Type}}</type>
|
||||
<description>{{.Description}}</description>
|
||||
</parameter>
|
||||
{{end}}</parameters>
|
||||
</tool_description>
|
||||
{{end}}</tools>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionResults) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_results").Parse(`<function_results>
|
||||
{{range .Result}}<result>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<stdout>{{.Stdout}}</stdout>
|
||||
</result>
|
||||
{{end}}</function_results>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (x XMLFunctionCalls) XMLString() (string, error) {
|
||||
tmpl, err := template.New("function_calls").Parse(`<function_calls>
|
||||
{{range .Invoke}}<invoke>
|
||||
<tool_name>{{.ToolName}}</tool_name>
|
||||
<parameters>{{.Parameters.String}}</parameters>
|
||||
</invoke>
|
||||
{{end}}</function_calls>`)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, x); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
278
pkg/lmcli/provider/openai/openai.go
Normal file
278
pkg/lmcli/provider/openai/openai.go
Normal file
@@ -0,0 +1,278 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/provider"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/tools"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
APIKey string
|
||||
}
|
||||
|
||||
type OpenAIToolParameters struct {
|
||||
Type string `json:"type"`
|
||||
Properties map[string]OpenAIToolParameter `json:"properties,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIToolParameter struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}
|
||||
|
||||
func convertTools(tools []model.Tool) []openai.Tool {
|
||||
openaiTools := make([]openai.Tool, len(tools))
|
||||
for i, tool := range tools {
|
||||
openaiTools[i].Type = "function"
|
||||
|
||||
params := make(map[string]OpenAIToolParameter)
|
||||
var required []string
|
||||
|
||||
for _, param := range tool.Parameters {
|
||||
params[param.Name] = OpenAIToolParameter{
|
||||
Type: param.Type,
|
||||
Description: param.Description,
|
||||
Enum: param.Enum,
|
||||
}
|
||||
if param.Required {
|
||||
required = append(required, param.Name)
|
||||
}
|
||||
}
|
||||
|
||||
openaiTools[i].Function = openai.FunctionDefinition{
|
||||
Name: tool.Name,
|
||||
Description: tool.Description,
|
||||
Parameters: OpenAIToolParameters{
|
||||
Type: "object",
|
||||
Properties: params,
|
||||
Required: required,
|
||||
},
|
||||
}
|
||||
}
|
||||
return openaiTools
|
||||
}
|
||||
|
||||
func convertToolCallToOpenAI(toolCalls []model.ToolCall) []openai.ToolCall {
|
||||
converted := make([]openai.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].Type = "function"
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Function.Name = call.Name
|
||||
|
||||
json, _ := json.Marshal(call.Parameters)
|
||||
converted[i].Function.Arguments = string(json)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func convertToolCallToAPI(toolCalls []openai.ToolCall) []model.ToolCall {
|
||||
converted := make([]model.ToolCall, len(toolCalls))
|
||||
for i, call := range toolCalls {
|
||||
converted[i].ID = call.ID
|
||||
converted[i].Name = call.Function.Name
|
||||
json.Unmarshal([]byte(call.Function.Arguments), &converted[i].Parameters)
|
||||
}
|
||||
return converted
|
||||
}
|
||||
|
||||
func createChatCompletionRequest(
|
||||
c *OpenAIClient,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
) openai.ChatCompletionRequest {
|
||||
requestMessages := make([]openai.ChatCompletionMessage, 0, len(messages))
|
||||
|
||||
for _, m := range messages {
|
||||
switch m.Role {
|
||||
case "tool_call":
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = "assistant"
|
||||
message.Content = m.Content
|
||||
message.ToolCalls = convertToolCallToOpenAI(m.ToolCalls)
|
||||
requestMessages = append(requestMessages, message)
|
||||
case "tool_result":
|
||||
// expand tool_result messages' results into multiple openAI messages
|
||||
for _, result := range m.ToolResults {
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = "tool"
|
||||
message.Content = result.Result
|
||||
message.ToolCallID = result.ToolCallID
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
default:
|
||||
message := openai.ChatCompletionMessage{}
|
||||
message.Role = string(m.Role)
|
||||
message.Content = m.Content
|
||||
requestMessages = append(requestMessages, message)
|
||||
}
|
||||
}
|
||||
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: params.Model,
|
||||
MaxTokens: params.MaxTokens,
|
||||
Temperature: params.Temperature,
|
||||
Messages: requestMessages,
|
||||
N: 1, // limit responses to 1 "choice". we use choices[0] to reference it
|
||||
}
|
||||
|
||||
if len(params.ToolBag) > 0 {
|
||||
request.Tools = convertTools(params.ToolBag)
|
||||
request.ToolChoice = "auto"
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
func handleToolCalls(
|
||||
params model.RequestParameters,
|
||||
content string,
|
||||
toolCalls []openai.ToolCall,
|
||||
) ([]model.Message, error) {
|
||||
toolCall := model.Message{
|
||||
Role: model.MessageRoleToolCall,
|
||||
Content: content,
|
||||
ToolCalls: convertToolCallToAPI(toolCalls),
|
||||
}
|
||||
|
||||
toolResults, err := tools.ExecuteToolCalls(toolCall.ToolCalls, params.ToolBag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toolResult := model.Message{
|
||||
Role: model.MessageRoleToolResult,
|
||||
ToolResults: toolResults,
|
||||
}
|
||||
|
||||
return []model.Message{toolCall, toolResult}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
resp, err := client.CreateChatCompletion(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
|
||||
toolCalls := choice.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
results, err := handleToolCalls(params, choice.Message.Content, toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if callback != nil {
|
||||
for _, result := range results {
|
||||
callback(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletion with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletion(ctx, params, messages, callback)
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: choice.Message.Content,
|
||||
})
|
||||
}
|
||||
|
||||
// Return the user-facing message.
|
||||
return choice.Message.Content, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback provider.ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error) {
|
||||
client := openai.NewClient(c.APIKey)
|
||||
req := createChatCompletionRequest(c, params, messages)
|
||||
|
||||
stream, err := client.CreateChatCompletionStream(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
content := strings.Builder{}
|
||||
toolCalls := []openai.ToolCall{}
|
||||
|
||||
// Iterate stream segments
|
||||
for {
|
||||
response, e := stream.Recv()
|
||||
if errors.Is(e, io.EOF) {
|
||||
break
|
||||
}
|
||||
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
|
||||
delta := response.Choices[0].Delta
|
||||
if len(delta.ToolCalls) > 0 {
|
||||
// Construct streamed tool_call arguments
|
||||
for _, tc := range delta.ToolCalls {
|
||||
if tc.Index == nil {
|
||||
return "", fmt.Errorf("Unexpected nil index for streamed tool call.")
|
||||
}
|
||||
if len(toolCalls) <= *tc.Index {
|
||||
toolCalls = append(toolCalls, tc)
|
||||
} else {
|
||||
toolCalls[*tc.Index].Function.Arguments += tc.Function.Arguments
|
||||
}
|
||||
}
|
||||
} else {
|
||||
output <- delta.Content
|
||||
content.WriteString(delta.Content)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toolCalls) > 0 {
|
||||
results, err := handleToolCalls(params, content.String(), toolCalls)
|
||||
if err != nil {
|
||||
return content.String(), err
|
||||
}
|
||||
|
||||
if callback != nil {
|
||||
for _, result := range results {
|
||||
callback(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into CreateChatCompletionStream with the tool call replies
|
||||
messages = append(messages, results...)
|
||||
return c.CreateChatCompletionStream(ctx, params, messages, callback, output)
|
||||
} else {
|
||||
if callback != nil {
|
||||
callback(model.Message{
|
||||
Role: model.MessageRoleAssistant,
|
||||
Content: content.String(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return content.String(), err
|
||||
}
|
||||
31
pkg/lmcli/provider/provider.go
Normal file
31
pkg/lmcli/provider/provider.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
type ReplyCallback func(model.Message)
|
||||
|
||||
type ChatCompletionClient interface {
|
||||
// CreateChatCompletion requests a response to the provided messages.
|
||||
// Replies are appended to the given replies struct, and the
|
||||
// complete user-facing response is returned as a string.
|
||||
CreateChatCompletion(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback ReplyCallback,
|
||||
) (string, error)
|
||||
|
||||
// Like CreateChageCompletion, except the response is streamed via
|
||||
// the output channel as it's received.
|
||||
CreateChatCompletionStream(
|
||||
ctx context.Context,
|
||||
params model.RequestParameters,
|
||||
messages []model.Message,
|
||||
callback ReplyCallback,
|
||||
output chan<- string,
|
||||
) (string, error)
|
||||
}
|
||||
132
pkg/lmcli/store.go
Normal file
132
pkg/lmcli/store.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package lmcli
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
sqids "github.com/sqids/sqids-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ConversationStore interface {
|
||||
Conversations() ([]model.Conversation, error)
|
||||
|
||||
ConversationByShortName(shortName string) (*model.Conversation, error)
|
||||
ConversationShortNameCompletions(search string) []string
|
||||
|
||||
SaveConversation(conversation *model.Conversation) error
|
||||
DeleteConversation(conversation *model.Conversation) error
|
||||
|
||||
Messages(conversation *model.Conversation) ([]model.Message, error)
|
||||
LastMessage(conversation *model.Conversation) (*model.Message, error)
|
||||
|
||||
SaveMessage(message *model.Message) error
|
||||
DeleteMessage(message *model.Message) error
|
||||
UpdateMessage(message *model.Message) error
|
||||
AddReply(conversation *model.Conversation, message model.Message) (*model.Message, error)
|
||||
}
|
||||
|
||||
type SQLStore struct {
|
||||
db *gorm.DB
|
||||
sqids *sqids.Sqids
|
||||
}
|
||||
|
||||
func NewSQLStore(db *gorm.DB) (*SQLStore, error) {
|
||||
models := []any{
|
||||
&model.Conversation{},
|
||||
&model.Message{},
|
||||
}
|
||||
|
||||
for _, x := range models {
|
||||
err := db.AutoMigrate(x)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not perform database migrations: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
_sqids, _ := sqids.New(sqids.Options{MinLength: 4})
|
||||
return &SQLStore{db, _sqids}, nil
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveConversation(conversation *model.Conversation) error {
|
||||
err := s.db.Save(&conversation).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !conversation.ShortName.Valid {
|
||||
shortName, _ := s.sqids.Encode([]uint64{uint64(conversation.ID)})
|
||||
conversation.ShortName = sql.NullString{String: shortName, Valid: true}
|
||||
err = s.db.Save(&conversation).Error
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteConversation(conversation *model.Conversation) error {
|
||||
s.db.Where("conversation_id = ?", conversation.ID).Delete(&model.Message{})
|
||||
return s.db.Delete(&conversation).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) SaveMessage(message *model.Message) error {
|
||||
return s.db.Create(message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) DeleteMessage(message *model.Message) error {
|
||||
return s.db.Delete(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) UpdateMessage(message *model.Message) error {
|
||||
return s.db.Updates(&message).Error
|
||||
}
|
||||
|
||||
func (s *SQLStore) Conversations() ([]model.Conversation, error) {
|
||||
var conversations []model.Conversation
|
||||
err := s.db.Find(&conversations).Error
|
||||
return conversations, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationShortNameCompletions(shortName string) []string {
|
||||
var completions []string
|
||||
conversations, _ := s.Conversations() // ignore error for completions
|
||||
for _, conversation := range conversations {
|
||||
if shortName == "" || strings.HasPrefix(conversation.ShortName.String, shortName) {
|
||||
completions = append(completions, fmt.Sprintf("%s\t%s", conversation.ShortName.String, conversation.Title))
|
||||
}
|
||||
}
|
||||
return completions
|
||||
}
|
||||
|
||||
func (s *SQLStore) ConversationByShortName(shortName string) (*model.Conversation, error) {
|
||||
if shortName == "" {
|
||||
return nil, errors.New("shortName is empty")
|
||||
}
|
||||
var conversation model.Conversation
|
||||
err := s.db.Where("short_name = ?", shortName).Find(&conversation).Error
|
||||
return &conversation, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) Messages(conversation *model.Conversation) ([]model.Message, error) {
|
||||
var messages []model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Find(&messages).Error
|
||||
return messages, err
|
||||
}
|
||||
|
||||
func (s *SQLStore) LastMessage(conversation *model.Conversation) (*model.Message, error) {
|
||||
var message model.Message
|
||||
err := s.db.Where("conversation_id = ?", conversation.ID).Last(&message).Error
|
||||
return &message, err
|
||||
}
|
||||
|
||||
// AddReply adds the given messages as a reply to the given conversation, can be
|
||||
// used to easily copy a message associated with one conversation, to another
|
||||
func (s *SQLStore) AddReply(c *model.Conversation, m model.Message) (*model.Message, error) {
|
||||
m.ConversationID = c.ID
|
||||
m.ID = 0
|
||||
m.CreatedAt = time.Time{}
|
||||
return &m, s.SaveMessage(&m)
|
||||
}
|
||||
114
pkg/lmcli/tools/file_insert_lines.go
Normal file
114
pkg/lmcli/tools/file_insert_lines.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
const FILE_INSERT_LINES_DESCRIPTION = `Insert lines into a file, must specify path.
|
||||
|
||||
Make sure your inserts match the flow and indentation of surrounding content.`
|
||||
|
||||
var FileInsertLinesTool = model.Tool{
|
||||
Name: "file_insert_lines",
|
||||
Description: FILE_INSERT_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "position",
|
||||
Type: "integer",
|
||||
Description: `Which line to insert content *before*.`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `The content to insert.`,
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var position int
|
||||
tmp, ok = args["position"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid position in function arguments: %v", tmp)
|
||||
}
|
||||
position = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileInsertLines(path, position, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileInsertLines(path string, position int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if position < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
before := lines[:position-1]
|
||||
after := lines[position-1:]
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
||||
133
pkg/lmcli/tools/file_replace_lines.go
Normal file
133
pkg/lmcli/tools/file_replace_lines.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const FILE_REPLACE_LINES_DESCRIPTION = `Replace or remove a range of lines within a file, must specify path.
|
||||
|
||||
Useful for re-writing snippets/blocks of code or entire functions.
|
||||
|
||||
Plan your edits carefully and ensure any new content matches the flow and indentation of surrounding text.`
|
||||
|
||||
var FileReplaceLinesTool = model.Tool{
|
||||
Name: "file_replace_lines",
|
||||
Description: FILE_REPLACE_LINES_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path of the file to be modified, relative to the current working directory.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "start_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the start of the replacement range (inclusive).`,
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "end_line",
|
||||
Type: "integer",
|
||||
Description: `Line number which specifies the end of the replacement range (inclusive). If unset, range extends to end of file.`,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: `Content to replace specified range. Omit to remove the specified range.`,
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
var start_line int
|
||||
tmp, ok = args["start_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid start_line in function arguments: %v", tmp)
|
||||
}
|
||||
start_line = int(tmp)
|
||||
}
|
||||
var end_line int
|
||||
tmp, ok = args["end_line"]
|
||||
if ok {
|
||||
tmp, ok := tmp.(float64)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid end_line in function arguments: %v", tmp)
|
||||
}
|
||||
end_line = int(tmp)
|
||||
}
|
||||
var content string
|
||||
tmp, ok = args["content"]
|
||||
if ok {
|
||||
content, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
|
||||
result := fileReplaceLines(path, start_line, end_line, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func fileReplaceLines(path string, startLine int, endLine int, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
// Read the existing file's content
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
_, err = os.Create(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not create new file: %s", err.Error())}
|
||||
}
|
||||
data = []byte{}
|
||||
}
|
||||
|
||||
if startLine < 1 {
|
||||
return model.CallResult{Message: "start_line cannot be less than 1"}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
contentLines := strings.Split(strings.Trim(content, "\n"), "\n")
|
||||
|
||||
if endLine == 0 || endLine > len(lines) {
|
||||
endLine = len(lines)
|
||||
}
|
||||
|
||||
before := lines[:startLine-1]
|
||||
after := lines[endLine:]
|
||||
|
||||
lines = append(before, append(contentLines, after...)...)
|
||||
newContent := strings.Join(lines, "\n")
|
||||
|
||||
// Join the lines and write back to the file
|
||||
err = os.WriteFile(path, []byte(newContent), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
|
||||
return model.CallResult{Result: newContent}
|
||||
}
|
||||
100
pkg/lmcli/tools/read_dir.go
Normal file
100
pkg/lmcli/tools/read_dir.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_DIR_DESCRIPTION = `Return the contents of the CWD (current working directory).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": [
|
||||
{"name": "a_file.txt", "type": "file", "size": 123},
|
||||
{"name": "a_directory/", "type": "dir", "size": 11},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
For files, size represents the size of the file, in bytes.
|
||||
For directories, size represents the number of entries in that directory.`
|
||||
|
||||
var ReadDirTool = model.Tool{
|
||||
Name: "read_dir",
|
||||
Description: READ_DIR_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "relative_dir",
|
||||
Type: "string",
|
||||
Description: "If set, read the contents of a directory relative to the current one.",
|
||||
},
|
||||
},
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
var relativeDir string
|
||||
tmp, ok := args["relative_dir"]
|
||||
if ok {
|
||||
relativeDir, ok = tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid relative_dir in function arguments: %v", tmp)
|
||||
}
|
||||
}
|
||||
result := readDir(relativeDir)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readDir(path string) model.CallResult {
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
|
||||
files, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return model.CallResult{
|
||||
Message: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
var dirContents []map[string]interface{}
|
||||
for _, f := range files {
|
||||
info, _ := f.Info()
|
||||
|
||||
name := f.Name()
|
||||
if strings.HasPrefix(name, ".") {
|
||||
// skip hidden files
|
||||
continue
|
||||
}
|
||||
|
||||
entryType := "file"
|
||||
size := info.Size()
|
||||
|
||||
if info.IsDir() {
|
||||
name += "/"
|
||||
entryType = "dir"
|
||||
subdirfiles, _ := os.ReadDir(filepath.Join(".", path, info.Name()))
|
||||
size = int64(len(subdirfiles))
|
||||
}
|
||||
|
||||
dirContents = append(dirContents, map[string]interface{}{
|
||||
"name": name,
|
||||
"type": entryType,
|
||||
"size": size,
|
||||
})
|
||||
}
|
||||
|
||||
return model.CallResult{Result: dirContents}
|
||||
}
|
||||
71
pkg/lmcli/tools/read_file.go
Normal file
71
pkg/lmcli/tools/read_file.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const READ_FILE_DESCRIPTION = `Read the contents of a text file relative to the current working directory.
|
||||
|
||||
Each line of the returned content is prefixed with its line number and a tab (\t).
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success",
|
||||
"result": "1\tthe contents\n2\tof the file\n"
|
||||
}`
|
||||
|
||||
var ReadFileTool = model.Tool{
|
||||
Name: "read_file",
|
||||
Description: READ_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to read.",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
|
||||
Impl: func(tool *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to read_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
result := readFile(path)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func readFile(path string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not read path: %s", err.Error())}
|
||||
}
|
||||
|
||||
lines := strings.Split(string(data), "\n")
|
||||
content := strings.Builder{}
|
||||
for i, line := range lines {
|
||||
content.WriteString(fmt.Sprintf("%d\t%s\n", i+1, line))
|
||||
}
|
||||
|
||||
return model.CallResult{
|
||||
Result: content.String(),
|
||||
}
|
||||
}
|
||||
47
pkg/lmcli/tools/tools.go
Normal file
47
pkg/lmcli/tools/tools.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
)
|
||||
|
||||
var AvailableTools map[string]model.Tool = map[string]model.Tool{
|
||||
"read_dir": ReadDirTool,
|
||||
"read_file": ReadFileTool,
|
||||
"write_file": WriteFileTool,
|
||||
"file_insert_lines": FileInsertLinesTool,
|
||||
"file_replace_lines": FileReplaceLinesTool,
|
||||
}
|
||||
|
||||
func ExecuteToolCalls(toolCalls []model.ToolCall, toolBag []model.Tool) ([]model.ToolResult, error) {
|
||||
var toolResults []model.ToolResult
|
||||
for _, toolCall := range toolCalls {
|
||||
var tool *model.Tool
|
||||
for _, available := range toolBag {
|
||||
if available.Name == toolCall.Name {
|
||||
tool = &available
|
||||
break
|
||||
}
|
||||
}
|
||||
if tool == nil {
|
||||
return nil, fmt.Errorf("Requested tool '%s' does not exist. Hallucination?", toolCall.Name)
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
result, err := tool.Impl(tool, toolCall.Parameters)
|
||||
if err != nil {
|
||||
// This can happen if the model missed or supplied invalid tool args
|
||||
return nil, fmt.Errorf("Tool '%s' error: %v\n", toolCall.Name, err)
|
||||
}
|
||||
|
||||
toolResult := model.ToolResult{
|
||||
ToolCallID: toolCall.ID,
|
||||
ToolName: toolCall.Name,
|
||||
Result: result,
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
return toolResults, nil
|
||||
}
|
||||
67
pkg/lmcli/tools/util/util.go
Normal file
67
pkg/lmcli/tools/util/util.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// isPathContained attempts to verify whether `path` is the same as or
|
||||
// contained within `directory`. It is overly cautious, returning false even if
|
||||
// `path` IS contained within `directory`, but the two paths use different
|
||||
// casing, and we happen to be on a case-insensitive filesystem.
|
||||
// This is ultimately to attempt to stop an LLM from going outside of where I
|
||||
// tell it to. Additional layers of security should be considered.. run in a
|
||||
// VM/container.
|
||||
func IsPathContained(directory string, path string) (bool, error) {
|
||||
// Clean and resolve symlinks for both paths
|
||||
path, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// check if path exists
|
||||
_, err = os.Stat(path)
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("Could not stat path: %v", err)
|
||||
}
|
||||
} else {
|
||||
path, err = filepath.EvalSymlinks(path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
directory, err = filepath.Abs(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
directory, err = filepath.EvalSymlinks(directory)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Case insensitive checks
|
||||
if !strings.EqualFold(path, directory) &&
|
||||
!strings.HasPrefix(strings.ToLower(path), strings.ToLower(directory)+string(os.PathSeparator)) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func IsPathWithinCWD(path string) (bool, string) {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false, "Failed to determine current working directory"
|
||||
}
|
||||
if ok, err := IsPathContained(cwd, path); !ok {
|
||||
if err != nil {
|
||||
return false, fmt.Sprintf("Could not determine whether path '%s' is within the current working directory: %s", path, err.Error())
|
||||
}
|
||||
return false, fmt.Sprintf("Path '%s' is not within the current working directory", path)
|
||||
}
|
||||
return true, ""
|
||||
}
|
||||
71
pkg/lmcli/tools/write_file.go
Normal file
71
pkg/lmcli/tools/write_file.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
toolutil "git.mlow.ca/mlow/lmcli/pkg/lmcli/tools/util"
|
||||
)
|
||||
|
||||
const WRITE_FILE_DESCRIPTION = `Write the provided contents to a file relative to the current working directory.
|
||||
|
||||
Example result:
|
||||
{
|
||||
"message": "success"
|
||||
}`
|
||||
|
||||
var WriteFileTool = model.Tool{
|
||||
Name: "write_file",
|
||||
Description: WRITE_FILE_DESCRIPTION,
|
||||
Parameters: []model.ToolParameter{
|
||||
{
|
||||
Name: "path",
|
||||
Type: "string",
|
||||
Description: "Path to a file within the current working directory to write to.",
|
||||
Required: true,
|
||||
},
|
||||
{
|
||||
Name: "content",
|
||||
Type: "string",
|
||||
Description: "The content to write to the file. Overwrites any existing content!",
|
||||
Required: true,
|
||||
},
|
||||
},
|
||||
Impl: func(t *model.Tool, args map[string]interface{}) (string, error) {
|
||||
tmp, ok := args["path"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Path parameter to write_file was not included.")
|
||||
}
|
||||
path, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid path in function arguments: %v", tmp)
|
||||
}
|
||||
tmp, ok = args["content"]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Content parameter to write_file was not included.")
|
||||
}
|
||||
content, ok := tmp.(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("Invalid content in function arguments: %v", tmp)
|
||||
}
|
||||
result := writeFile(path, content)
|
||||
ret, err := result.ToJson()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Could not serialize result: %v", err)
|
||||
}
|
||||
return ret, nil
|
||||
},
|
||||
}
|
||||
|
||||
func writeFile(path string, content string) model.CallResult {
|
||||
ok, reason := toolutil.IsPathWithinCWD(path)
|
||||
if !ok {
|
||||
return model.CallResult{Message: reason}
|
||||
}
|
||||
err := os.WriteFile(path, []byte(content), 0644)
|
||||
if err != nil {
|
||||
return model.CallResult{Message: fmt.Sprintf("Could not write to path: %s", err.Error())}
|
||||
}
|
||||
return model.CallResult{}
|
||||
}
|
||||
846
pkg/tui/tui.go
Normal file
846
pkg/tui/tui.go
Normal file
@@ -0,0 +1,846 @@
|
||||
package tui
|
||||
|
||||
// The terminal UI for lmcli, launched from the `lmcli chat` command
|
||||
// TODO:
|
||||
// - conversation list view
|
||||
// - change model
|
||||
// - rename conversation
|
||||
// - set system prompt
|
||||
// - system prompt library?
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
cmdutil "git.mlow.ca/mlow/lmcli/pkg/cmd/util"
|
||||
"git.mlow.ca/mlow/lmcli/pkg/lmcli"
|
||||
models "git.mlow.ca/mlow/lmcli/pkg/lmcli/model"
|
||||
"github.com/charmbracelet/bubbles/spinner"
|
||||
"github.com/charmbracelet/bubbles/textarea"
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/muesli/reflow/wordwrap"
|
||||
)
|
||||
|
||||
type focusState int
|
||||
|
||||
const (
|
||||
focusInput focusState = iota
|
||||
focusMessages
|
||||
)
|
||||
|
||||
type editorTarget int
|
||||
|
||||
const (
|
||||
input editorTarget = iota
|
||||
selectedMessage
|
||||
)
|
||||
|
||||
type model struct {
|
||||
width int
|
||||
height int
|
||||
|
||||
ctx *lmcli.Context
|
||||
convShortname string
|
||||
|
||||
// application state
|
||||
conversation *models.Conversation
|
||||
messages []models.Message
|
||||
waitingForReply bool
|
||||
editorTarget editorTarget
|
||||
stopSignal chan interface{}
|
||||
replyChan chan models.Message
|
||||
replyChunkChan chan string
|
||||
persistence bool // whether we will save new messages in the conversation
|
||||
err error
|
||||
|
||||
// ui state
|
||||
focus focusState
|
||||
wrap bool // whether message content is wrapped to viewport width
|
||||
status string // a general status message
|
||||
highlightCache []string // a cache of syntax highlighted message content
|
||||
messageOffsets []int
|
||||
selectedMessage int
|
||||
|
||||
// ui elements
|
||||
content viewport.Model
|
||||
input textarea.Model
|
||||
spinner spinner.Model
|
||||
}
|
||||
|
||||
type message struct {
|
||||
role string
|
||||
content string
|
||||
}
|
||||
|
||||
// custom tea.Msg types
|
||||
type (
|
||||
// sent on each chunk received from LLM
|
||||
msgResponseChunk string
|
||||
// sent when response is finished being received
|
||||
msgResponseEnd string
|
||||
// a special case of msgError that stops the response waiting animation
|
||||
msgResponseError error
|
||||
// sent on each completed reply
|
||||
msgAssistantReply models.Message
|
||||
// sent when a conversation is (re)loaded
|
||||
msgConversationLoaded *models.Conversation
|
||||
// sent when a new conversation title is set
|
||||
msgConversationTitleChanged string
|
||||
// send when a conversation's messages are laoded
|
||||
msgMessagesLoaded []models.Message
|
||||
// sent when an error occurs
|
||||
msgError error
|
||||
)
|
||||
|
||||
// styles
|
||||
var (
|
||||
userStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("10"))
|
||||
assistantStyle = lipgloss.NewStyle().Faint(true).Bold(true).Foreground(lipgloss.Color("12"))
|
||||
messageStyle = lipgloss.NewStyle().PaddingLeft(2).PaddingRight(2)
|
||||
headerStyle = lipgloss.NewStyle().
|
||||
Background(lipgloss.Color("0"))
|
||||
conversationStyle = lipgloss.NewStyle().
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
footerStyle = lipgloss.NewStyle().
|
||||
BorderTop(true).
|
||||
BorderStyle(lipgloss.NormalBorder())
|
||||
)
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
textarea.Blink,
|
||||
m.spinner.Tick,
|
||||
m.loadConversation(m.convShortname),
|
||||
m.waitForChunk(),
|
||||
m.waitForReply(),
|
||||
)
|
||||
}
|
||||
|
||||
func wrapError(err error) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return msgError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case msgTempfileEditorClosed:
|
||||
contents := string(msg)
|
||||
switch m.editorTarget {
|
||||
case input:
|
||||
m.input.SetValue(contents)
|
||||
case selectedMessage:
|
||||
m.setMessageContents(m.selectedMessage, contents)
|
||||
if m.persistence && m.messages[m.selectedMessage].ID > 0 {
|
||||
// update persisted message
|
||||
err := m.ctx.Store.UpdateMessage(&m.messages[m.selectedMessage])
|
||||
if err != nil {
|
||||
cmds = append(cmds, wrapError(fmt.Errorf("Could not save edited message: %v", err)))
|
||||
}
|
||||
}
|
||||
m.updateContent()
|
||||
}
|
||||
case tea.KeyMsg:
|
||||
switch msg.String() {
|
||||
case "ctrl+c":
|
||||
if m.waitingForReply {
|
||||
m.stopSignal <- ""
|
||||
} else {
|
||||
return m, tea.Quit
|
||||
}
|
||||
case "ctrl+p":
|
||||
m.persistence = !m.persistence
|
||||
case "ctrl+w":
|
||||
m.wrap = !m.wrap
|
||||
m.updateContent()
|
||||
case "q":
|
||||
if m.focus != focusInput {
|
||||
return m, tea.Quit
|
||||
}
|
||||
default:
|
||||
var inputHandled tea.Cmd
|
||||
switch m.focus {
|
||||
case focusInput:
|
||||
inputHandled = m.handleInputKey(msg)
|
||||
case focusMessages:
|
||||
inputHandled = m.handleMessagesKey(msg)
|
||||
}
|
||||
if inputHandled != nil {
|
||||
return m, inputHandled
|
||||
}
|
||||
}
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.content.Width = msg.Width
|
||||
m.content.Height = msg.Height - m.getFixedComponentHeight()
|
||||
m.input.SetWidth(msg.Width - 1)
|
||||
m.updateContent()
|
||||
case msgConversationLoaded:
|
||||
m.conversation = (*models.Conversation)(msg)
|
||||
cmds = append(cmds, m.loadMessages(m.conversation))
|
||||
case msgMessagesLoaded:
|
||||
m.setMessages(msg)
|
||||
m.updateContent()
|
||||
case msgResponseChunk:
|
||||
chunk := string(msg)
|
||||
last := len(m.messages) - 1
|
||||
if last >= 0 && m.messages[last].Role == models.MessageRoleAssistant {
|
||||
m.setMessageContents(last, m.messages[last].Content+chunk)
|
||||
} else {
|
||||
m.addMessage(models.Message{
|
||||
Role: models.MessageRoleAssistant,
|
||||
Content: chunk,
|
||||
})
|
||||
}
|
||||
m.updateContent()
|
||||
cmds = append(cmds, m.waitForChunk()) // wait for the next chunk
|
||||
case msgAssistantReply:
|
||||
// the last reply that was being worked on is finished
|
||||
reply := models.Message(msg)
|
||||
last := len(m.messages) - 1
|
||||
if last < 0 {
|
||||
panic("Unexpected empty messages handling msgReply")
|
||||
}
|
||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
||||
if m.messages[last].Role == models.MessageRoleAssistant {
|
||||
// the last message was an assistant message, so this is a continuation
|
||||
if reply.Role == models.MessageRoleToolCall {
|
||||
// update last message rrole to tool call
|
||||
m.messages[last].Role = models.MessageRoleToolCall
|
||||
}
|
||||
} else {
|
||||
m.addMessage(reply)
|
||||
}
|
||||
|
||||
if m.persistence {
|
||||
var err error
|
||||
if m.conversation.ID == 0 {
|
||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
||||
}
|
||||
if err != nil {
|
||||
cmds = append(cmds, wrapError(err))
|
||||
} else {
|
||||
cmds = append(cmds, m.persistConversation())
|
||||
}
|
||||
}
|
||||
|
||||
if m.conversation.Title == "" {
|
||||
cmds = append(cmds, m.generateConversationTitle())
|
||||
}
|
||||
|
||||
m.updateContent()
|
||||
cmds = append(cmds, m.waitForReply())
|
||||
case msgResponseEnd:
|
||||
m.waitingForReply = false
|
||||
last := len(m.messages) - 1
|
||||
if last < 0 {
|
||||
panic("Unexpected empty messages handling msgResponseEnd")
|
||||
}
|
||||
m.setMessageContents(last, strings.TrimSpace(m.messages[last].Content))
|
||||
m.updateContent()
|
||||
m.status = "Press ctrl+s to send"
|
||||
case msgResponseError:
|
||||
m.waitingForReply = false
|
||||
m.status = "Press ctrl+s to send"
|
||||
m.err = error(msg)
|
||||
case msgConversationTitleChanged:
|
||||
title := string(msg)
|
||||
m.conversation.Title = title
|
||||
if m.persistence {
|
||||
err := m.ctx.Store.SaveConversation(m.conversation)
|
||||
if err != nil {
|
||||
cmds = append(cmds, wrapError(err))
|
||||
}
|
||||
}
|
||||
case msgError:
|
||||
m.err = error(msg)
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.spinner, cmd = m.spinner.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
inputCaptured := false
|
||||
m.input, cmd = m.input.Update(msg)
|
||||
if cmd != nil {
|
||||
inputCaptured = true
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
if !inputCaptured {
|
||||
m.content, cmd = m.content.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
if m.width == 0 {
|
||||
// this is the case upon initial startup, but it's also a safe bet that
|
||||
// we can just skip rendering if the terminal is really 0 width...
|
||||
// without this, the m.*View() functions may crash
|
||||
return ""
|
||||
}
|
||||
|
||||
sections := make([]string, 0, 6)
|
||||
sections = append(sections, m.headerView())
|
||||
sections = append(sections, m.contentView())
|
||||
error := m.errorView()
|
||||
if error != "" {
|
||||
sections = append(sections, error)
|
||||
}
|
||||
sections = append(sections, m.inputView())
|
||||
sections = append(sections, m.footerView())
|
||||
|
||||
return lipgloss.JoinVertical(
|
||||
lipgloss.Left,
|
||||
sections...,
|
||||
)
|
||||
}
|
||||
|
||||
// returns the total height of "fixed" components, which are those which don't
|
||||
// change height dependent on window size.
|
||||
func (m *model) getFixedComponentHeight() int {
|
||||
h := 0
|
||||
h += m.input.Height()
|
||||
h += lipgloss.Height(m.headerView())
|
||||
h += lipgloss.Height(m.footerView())
|
||||
errorView := m.errorView()
|
||||
if errorView != "" {
|
||||
h += lipgloss.Height(errorView)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (m *model) headerView() string {
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
PaddingLeft(1).
|
||||
PaddingRight(1).
|
||||
Bold(true)
|
||||
var title string
|
||||
if m.conversation != nil && m.conversation.Title != "" {
|
||||
title = m.conversation.Title
|
||||
} else {
|
||||
title = "Untitled"
|
||||
}
|
||||
part := titleStyle.Render(title)
|
||||
|
||||
return headerStyle.Width(m.width).Render(part)
|
||||
}
|
||||
|
||||
func (m *model) contentView() string {
|
||||
return m.content.View()
|
||||
}
|
||||
|
||||
func (m *model) errorView() string {
|
||||
if m.err == nil {
|
||||
return ""
|
||||
}
|
||||
return lipgloss.NewStyle().
|
||||
Width(m.width).
|
||||
AlignHorizontal(lipgloss.Center).
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("1")).
|
||||
Render(fmt.Sprintf("%s", m.err))
|
||||
}
|
||||
|
||||
func (m *model) inputView() string {
|
||||
return m.input.View()
|
||||
}
|
||||
|
||||
func (m *model) footerView() string {
|
||||
segmentStyle := lipgloss.NewStyle().PaddingLeft(1).PaddingRight(1).Faint(true)
|
||||
segmentSeparator := "|"
|
||||
|
||||
savingStyle := segmentStyle.Copy().Bold(true)
|
||||
saving := ""
|
||||
if m.persistence {
|
||||
saving = savingStyle.Foreground(lipgloss.Color("2")).Render("✅💾")
|
||||
} else {
|
||||
saving = savingStyle.Foreground(lipgloss.Color("1")).Render("❌💾")
|
||||
}
|
||||
|
||||
status := m.status
|
||||
if m.waitingForReply {
|
||||
status += m.spinner.View()
|
||||
}
|
||||
|
||||
leftSegments := []string{
|
||||
saving,
|
||||
segmentStyle.Render(status),
|
||||
}
|
||||
rightSegments := []string{
|
||||
segmentStyle.Render(fmt.Sprintf("Model: %s", *m.ctx.Config.Defaults.Model)),
|
||||
}
|
||||
|
||||
left := strings.Join(leftSegments, segmentSeparator)
|
||||
right := strings.Join(rightSegments, segmentSeparator)
|
||||
|
||||
totalWidth := lipgloss.Width(left) + lipgloss.Width(right)
|
||||
remaining := m.width - totalWidth
|
||||
|
||||
var padding string
|
||||
if remaining > 0 {
|
||||
padding = strings.Repeat(" ", remaining)
|
||||
}
|
||||
|
||||
footer := left + padding + right
|
||||
if remaining < 0 {
|
||||
ellipses := "... "
|
||||
// this doesn't work very well, due to trying to trim a string with
|
||||
// ansii chars already in it
|
||||
footer = footer[:(len(footer)+remaining)-len(ellipses)-3] + ellipses
|
||||
}
|
||||
return footerStyle.Width(m.width).Render(footer)
|
||||
}
|
||||
|
||||
func initialModel(ctx *lmcli.Context, convShortname string) model {
|
||||
m := model{
|
||||
ctx: ctx,
|
||||
convShortname: convShortname,
|
||||
conversation: &models.Conversation{},
|
||||
persistence: true,
|
||||
|
||||
stopSignal: make(chan interface{}),
|
||||
replyChan: make(chan models.Message),
|
||||
replyChunkChan: make(chan string),
|
||||
|
||||
wrap: true,
|
||||
selectedMessage: -1,
|
||||
}
|
||||
|
||||
m.content = viewport.New(0, 0)
|
||||
|
||||
m.input = textarea.New()
|
||||
m.input.CharLimit = 0
|
||||
m.input.Placeholder = "Enter a message"
|
||||
|
||||
m.input.FocusedStyle.CursorLine = lipgloss.NewStyle()
|
||||
m.input.ShowLineNumbers = false
|
||||
m.input.SetHeight(4)
|
||||
m.input.Focus()
|
||||
|
||||
m.spinner = spinner.New(spinner.WithSpinner(
|
||||
spinner.Spinner{
|
||||
Frames: []string{
|
||||
". ",
|
||||
".. ",
|
||||
"...",
|
||||
".. ",
|
||||
". ",
|
||||
" ",
|
||||
},
|
||||
FPS: time.Second / 3,
|
||||
},
|
||||
))
|
||||
|
||||
m.waitingForReply = false
|
||||
m.status = "Press ctrl+s to send"
|
||||
return m
|
||||
}
|
||||
|
||||
// fraction is the fraction of the total screen height into view the offset
|
||||
// should be scrolled into view. 0.5 = items will be snapped to middle of
|
||||
// view
|
||||
func scrollIntoView(vp *viewport.Model, offset int, fraction float32) {
|
||||
currentOffset := vp.YOffset
|
||||
if offset >= currentOffset && offset < currentOffset+vp.Height {
|
||||
return
|
||||
}
|
||||
distance := currentOffset - offset
|
||||
if distance < 0 {
|
||||
// we should scroll down until it just comes into view
|
||||
vp.SetYOffset(currentOffset - (distance + (vp.Height - int(float32(vp.Height)*fraction))) + 1)
|
||||
} else {
|
||||
// we should scroll up
|
||||
vp.SetYOffset(currentOffset - distance - int(float32(vp.Height)*fraction))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) handleMessagesKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "tab":
|
||||
m.focus = focusInput
|
||||
m.updateContent()
|
||||
m.input.Focus()
|
||||
case "e":
|
||||
message := m.messages[m.selectedMessage]
|
||||
cmd := openTempfileEditor("message.*.md", message.Content, "# Edit the message below\n")
|
||||
m.editorTarget = selectedMessage
|
||||
return cmd
|
||||
case "ctrl+k":
|
||||
if m.selectedMessage > 0 && len(m.messages) == len(m.messageOffsets) {
|
||||
m.selectedMessage--
|
||||
m.updateContent()
|
||||
offset := m.messageOffsets[m.selectedMessage]
|
||||
scrollIntoView(&m.content, offset, 0.1)
|
||||
}
|
||||
case "ctrl+j":
|
||||
if m.selectedMessage < len(m.messages)-1 && len(m.messages) == len(m.messageOffsets) {
|
||||
m.selectedMessage++
|
||||
m.updateContent()
|
||||
offset := m.messageOffsets[m.selectedMessage]
|
||||
scrollIntoView(&m.content, offset, 0.1)
|
||||
}
|
||||
case "ctrl+r":
|
||||
// resubmit the conversation with all messages up until and including
|
||||
// the selected message
|
||||
if len(m.messages) == 0 {
|
||||
return nil
|
||||
}
|
||||
m.messages = m.messages[:m.selectedMessage+1]
|
||||
m.highlightCache = m.highlightCache[:m.selectedMessage+1]
|
||||
m.updateContent()
|
||||
m.content.GotoBottom()
|
||||
return m.promptLLM()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) handleInputKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.focus = focusMessages
|
||||
if m.selectedMessage < 0 || m.selectedMessage >= len(m.messages) {
|
||||
m.selectedMessage = len(m.messages) - 1
|
||||
}
|
||||
m.updateContent()
|
||||
m.input.Blur()
|
||||
case "ctrl+s":
|
||||
userInput := strings.TrimSpace(m.input.Value())
|
||||
if strings.TrimSpace(userInput) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(m.messages) > 0 && m.messages[len(m.messages)-1].Role == models.MessageRoleUser {
|
||||
return wrapError(fmt.Errorf("Can't reply to a user message"))
|
||||
}
|
||||
|
||||
reply := models.Message{
|
||||
Role: models.MessageRoleUser,
|
||||
Content: userInput,
|
||||
}
|
||||
|
||||
if m.persistence {
|
||||
var err error
|
||||
if m.conversation.ID == 0 {
|
||||
err = m.ctx.Store.SaveConversation(m.conversation)
|
||||
}
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// ensure all messages up to the one we're about to add are
|
||||
// persistent
|
||||
cmd := m.persistConversation()
|
||||
if cmd != nil {
|
||||
return cmd
|
||||
}
|
||||
// persist our new message, returning with any possible errors
|
||||
savedReply, err := m.ctx.Store.AddReply(m.conversation, reply)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
reply = *savedReply
|
||||
}
|
||||
|
||||
m.input.SetValue("")
|
||||
m.addMessage(reply)
|
||||
|
||||
m.updateContent()
|
||||
m.content.GotoBottom()
|
||||
return m.promptLLM()
|
||||
case "ctrl+e":
|
||||
cmd := openTempfileEditor("message.*.md", m.input.Value(), "# Edit your input below\n")
|
||||
m.editorTarget = input
|
||||
return cmd
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) loadConversation(shortname string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if shortname == "" {
|
||||
return nil
|
||||
}
|
||||
c, err := m.ctx.Store.ConversationByShortName(shortname)
|
||||
if err != nil {
|
||||
return msgError(fmt.Errorf("Could not lookup conversation: %v", err))
|
||||
}
|
||||
if c.ID == 0 {
|
||||
return msgError(fmt.Errorf("Conversation not found: %s", shortname))
|
||||
}
|
||||
return msgConversationLoaded(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) loadMessages(c *models.Conversation) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
messages, err := m.ctx.Store.Messages(c)
|
||||
if err != nil {
|
||||
return msgError(fmt.Errorf("Could not load conversation messages: %v\n", err))
|
||||
}
|
||||
return msgMessagesLoaded(messages)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) waitForReply() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return msgAssistantReply(<-m.replyChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) waitForChunk() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
return msgResponseChunk(<-m.replyChunkChan)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) generateConversationTitle() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
title, err := cmdutil.GenerateTitle(m.ctx, m.conversation)
|
||||
if err != nil {
|
||||
return msgError(err)
|
||||
}
|
||||
return msgConversationTitleChanged(title)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) promptLLM() tea.Cmd {
|
||||
m.waitingForReply = true
|
||||
m.status = "Press ctrl+c to cancel"
|
||||
|
||||
return func() tea.Msg {
|
||||
completionProvider, err := m.ctx.GetCompletionProvider(*m.ctx.Config.Defaults.Model)
|
||||
if err != nil {
|
||||
return msgError(err)
|
||||
}
|
||||
|
||||
requestParams := models.RequestParameters{
|
||||
Model: *m.ctx.Config.Defaults.Model,
|
||||
MaxTokens: *m.ctx.Config.Defaults.MaxTokens,
|
||||
Temperature: *m.ctx.Config.Defaults.Temperature,
|
||||
ToolBag: m.ctx.EnabledTools,
|
||||
}
|
||||
|
||||
replyHandler := func(msg models.Message) {
|
||||
m.replyChan <- msg
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
canceled := false
|
||||
go func() {
|
||||
select {
|
||||
case <-m.stopSignal:
|
||||
canceled = true
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
resp, err := completionProvider.CreateChatCompletionStream(
|
||||
ctx, requestParams, m.messages, replyHandler, m.replyChunkChan,
|
||||
)
|
||||
|
||||
if err != nil && !canceled {
|
||||
return msgResponseError(err)
|
||||
}
|
||||
|
||||
return msgResponseEnd(resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) persistConversation() tea.Cmd {
|
||||
existingMessages, err := m.ctx.Store.Messages(m.conversation)
|
||||
if err != nil {
|
||||
return wrapError(fmt.Errorf("Could not retrieve existing conversation messages while trying to save: %v", err))
|
||||
}
|
||||
|
||||
existingById := make(map[uint]*models.Message, len(existingMessages))
|
||||
for _, msg := range existingMessages {
|
||||
existingById[msg.ID] = &msg
|
||||
}
|
||||
|
||||
currentById := make(map[uint]*models.Message, len(m.messages))
|
||||
for _, msg := range m.messages {
|
||||
currentById[msg.ID] = &msg
|
||||
}
|
||||
|
||||
for _, msg := range existingMessages {
|
||||
_, ok := currentById[msg.ID]
|
||||
if !ok {
|
||||
err := m.ctx.Store.DeleteMessage(&msg)
|
||||
if err != nil {
|
||||
return wrapError(fmt.Errorf("Failed to remove messages: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, msg := range m.messages {
|
||||
if msg.ID > 0 {
|
||||
exist, ok := existingById[msg.ID]
|
||||
if ok {
|
||||
if msg.Content == exist.Content {
|
||||
continue
|
||||
}
|
||||
// update message when contents don't match that of store
|
||||
err := m.ctx.Store.UpdateMessage(&msg)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
} else {
|
||||
// this would be quite odd... and I'm not sure how to handle
|
||||
// it at the time of writing this
|
||||
}
|
||||
} else {
|
||||
newMessage, err := m.ctx.Store.AddReply(m.conversation, msg)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
m.setMessage(i, *newMessage)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *model) setMessages(messages []models.Message) {
|
||||
m.messages = messages
|
||||
m.highlightCache = make([]string, len(messages))
|
||||
for i, msg := range m.messages {
|
||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
||||
m.highlightCache[i] = highlighted
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) setMessage(i int, msg models.Message) {
|
||||
if i >= len(m.messages) {
|
||||
panic("i out of range")
|
||||
}
|
||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
||||
m.messages[i] = msg
|
||||
m.highlightCache[i] = highlighted
|
||||
}
|
||||
|
||||
func (m *model) addMessage(msg models.Message) {
|
||||
highlighted, _ := m.ctx.Chroma.HighlightS(msg.Content)
|
||||
m.messages = append(m.messages, msg)
|
||||
m.highlightCache = append(m.highlightCache, highlighted)
|
||||
}
|
||||
|
||||
func (m *model) setMessageContents(i int, content string) {
|
||||
if i >= len(m.messages) {
|
||||
panic("i out of range")
|
||||
}
|
||||
highlighted, _ := m.ctx.Chroma.HighlightS(content)
|
||||
m.messages[i].Content = content
|
||||
m.highlightCache[i] = highlighted
|
||||
}
|
||||
|
||||
func (m *model) updateContent() {
|
||||
atBottom := m.content.AtBottom()
|
||||
m.content.SetContent(m.conversationView())
|
||||
if atBottom {
|
||||
// if we were at bottom before the update, scroll with the output
|
||||
m.content.GotoBottom()
|
||||
}
|
||||
}
|
||||
|
||||
// render the conversation into a string
|
||||
func (m *model) conversationView() string {
|
||||
sb := strings.Builder{}
|
||||
msgCnt := len(m.messages)
|
||||
|
||||
m.messageOffsets = make([]int, len(m.messages))
|
||||
lineCnt := conversationStyle.GetMarginTop()
|
||||
for i, message := range m.messages {
|
||||
m.messageOffsets[i] = lineCnt
|
||||
|
||||
icon := "⚙️"
|
||||
friendly := message.Role.FriendlyRole()
|
||||
style := lipgloss.NewStyle().Bold(true).Faint(true)
|
||||
|
||||
switch message.Role {
|
||||
case models.MessageRoleUser:
|
||||
icon = ""
|
||||
style = userStyle
|
||||
case models.MessageRoleAssistant:
|
||||
icon = ""
|
||||
style = assistantStyle
|
||||
case models.MessageRoleToolCall, models.MessageRoleToolResult:
|
||||
icon = "🔧"
|
||||
}
|
||||
|
||||
// write message heading with space for content
|
||||
user := style.Render(icon + friendly)
|
||||
|
||||
var prefix string
|
||||
var suffix string
|
||||
|
||||
faint := lipgloss.NewStyle().Faint(true)
|
||||
if m.focus == focusMessages {
|
||||
if i == m.selectedMessage {
|
||||
prefix = "> "
|
||||
}
|
||||
suffix += faint.Render(fmt.Sprintf(" (%d/%d)", i+1, msgCnt))
|
||||
}
|
||||
|
||||
if message.ID == 0 {
|
||||
suffix += faint.Render(" (not saved)")
|
||||
}
|
||||
|
||||
header := lipgloss.NewStyle().PaddingLeft(1).Render(prefix + user + suffix)
|
||||
sb.WriteString(header)
|
||||
lineCnt += lipgloss.Height(header)
|
||||
|
||||
// TODO: special rendering for tool calls/results?
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n\n")
|
||||
lineCnt += 1
|
||||
|
||||
// write message contents
|
||||
var highlighted string
|
||||
if m.highlightCache[i] == "" {
|
||||
highlighted = message.Content
|
||||
} else {
|
||||
highlighted = m.highlightCache[i]
|
||||
}
|
||||
var contents string
|
||||
if m.wrap {
|
||||
wrapWidth := m.content.Width - messageStyle.GetHorizontalPadding() - 2
|
||||
wrapped := wordwrap.String(highlighted, wrapWidth)
|
||||
contents = wrapped
|
||||
} else {
|
||||
contents = highlighted
|
||||
}
|
||||
sb.WriteString(messageStyle.Width(0).Render(contents))
|
||||
lineCnt += lipgloss.Height(contents)
|
||||
}
|
||||
|
||||
if i < msgCnt-1 {
|
||||
sb.WriteString("\n\n")
|
||||
lineCnt += 1
|
||||
}
|
||||
}
|
||||
return conversationStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func Launch(ctx *lmcli.Context, convShortname string) error {
|
||||
p := tea.NewProgram(initialModel(ctx, convShortname), tea.WithAltScreen())
|
||||
if _, err := p.Run(); err != nil {
|
||||
return fmt.Errorf("Error running program: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
42
pkg/tui/util.go
Normal file
42
pkg/tui/util.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type msgTempfileEditorClosed string
|
||||
|
||||
// openTempfileEditor opens an $EDITOR on a new temporary file with the given
|
||||
// content. Upon closing, the contents of the file are read back returned
|
||||
// wrapped in a msgTempfileEditorClosed returned by the tea.Cmd
|
||||
func openTempfileEditor(pattern string, content string, placeholder string) tea.Cmd {
|
||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||
|
||||
err := os.WriteFile(msgFile.Name(), []byte(placeholder+content), os.ModeAppend)
|
||||
if err != nil {
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
editor := os.Getenv("EDITOR")
|
||||
if editor == "" {
|
||||
editor = "vim"
|
||||
}
|
||||
|
||||
c := exec.Command(editor, msgFile.Name())
|
||||
return tea.ExecProcess(c, func(err error) tea.Msg {
|
||||
bytes, err := os.ReadFile(msgFile.Name())
|
||||
if err != nil {
|
||||
return msgError(err)
|
||||
}
|
||||
fileContents := string(bytes)
|
||||
if strings.HasPrefix(fileContents, placeholder) {
|
||||
fileContents = fileContents[len(placeholder):]
|
||||
}
|
||||
stripped := strings.Trim(fileContents, "\n \t")
|
||||
return msgTempfileEditorClosed(stripped)
|
||||
})
|
||||
}
|
||||
60
pkg/util/tty/highlight.go
Normal file
60
pkg/util/tty/highlight.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package tty
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/chroma/v2"
|
||||
"github.com/alecthomas/chroma/v2/formatters"
|
||||
"github.com/alecthomas/chroma/v2/lexers"
|
||||
"github.com/alecthomas/chroma/v2/styles"
|
||||
)
|
||||
|
||||
type ChromaHighlighter struct {
|
||||
lexer chroma.Lexer
|
||||
formatter chroma.Formatter
|
||||
style *chroma.Style
|
||||
}
|
||||
|
||||
func NewChromaHighlighter(lang, format, style string) *ChromaHighlighter {
|
||||
l := lexers.Get(lang)
|
||||
if l == nil {
|
||||
l = lexers.Fallback
|
||||
}
|
||||
l = chroma.Coalesce(l)
|
||||
|
||||
f := formatters.Get(format)
|
||||
if f == nil {
|
||||
f = formatters.Fallback
|
||||
}
|
||||
|
||||
s := styles.Get(style)
|
||||
if s == nil {
|
||||
s = styles.Fallback
|
||||
}
|
||||
|
||||
return &ChromaHighlighter{
|
||||
lexer: l,
|
||||
formatter: f,
|
||||
style: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) Highlight(w io.Writer, text string) error {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.formatter.Format(w, s.style, it)
|
||||
}
|
||||
|
||||
func (s *ChromaHighlighter) HighlightS(text string) (string, error) {
|
||||
it, err := s.lexer.Tokenise(nil, text)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb := strings.Builder{}
|
||||
sb.Grow(len(text) * 2)
|
||||
s.formatter.Format(&sb, s.style, it)
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package cli
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -17,11 +17,11 @@ import (
|
||||
// contents of the file exactly match the value of placeholder (no edits to the
|
||||
// file were made), then an empty string is returned. Otherwise, the contents
|
||||
// are returned. Example patten: message.*.md
|
||||
func InputFromEditor(placeholder string, pattern string) (string, error) {
|
||||
func InputFromEditor(placeholder string, pattern string, content string) (string, error) {
|
||||
msgFile, _ := os.CreateTemp("/tmp", pattern)
|
||||
defer os.Remove(msgFile.Name())
|
||||
|
||||
os.WriteFile(msgFile.Name(), []byte(placeholder), os.ModeAppend)
|
||||
os.WriteFile(msgFile.Name(), []byte(placeholder + content), os.ModeAppend)
|
||||
|
||||
editor := os.Getenv("EDITOR")
|
||||
if editor == "" {
|
||||
@@ -38,7 +38,7 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
|
||||
}
|
||||
|
||||
bytes, _ := os.ReadFile(msgFile.Name())
|
||||
content := string(bytes)
|
||||
content = string(bytes)
|
||||
|
||||
if placeholder != "" {
|
||||
if content == placeholder {
|
||||
@@ -56,7 +56,7 @@ func InputFromEditor(placeholder string, pattern string) (string, error) {
|
||||
|
||||
// humanTimeElapsedSince returns a human-friendly "in the past" representation
|
||||
// of the given duration.
|
||||
func humanTimeElapsedSince(d time.Duration) string {
|
||||
func HumanTimeElapsedSince(d time.Duration) string {
|
||||
seconds := d.Seconds()
|
||||
minutes := seconds / 60
|
||||
hours := minutes / 60
|
||||
@@ -151,6 +151,14 @@ func SetStructDefaults(data interface{}) bool {
|
||||
intValue, _ := strconv.ParseInt(defaultTag, 10, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetInt(intValue)
|
||||
case reflect.Float32:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 32)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Float64:
|
||||
floatValue, _ := strconv.ParseFloat(defaultTag, 64)
|
||||
field.Set(reflect.New(e))
|
||||
field.Elem().SetFloat(floatValue)
|
||||
case reflect.Bool:
|
||||
boolValue := defaultTag == "true"
|
||||
field.Set(reflect.ValueOf(&boolValue))
|
||||
@@ -160,10 +168,8 @@ func SetStructDefaults(data interface{}) bool {
|
||||
return changed
|
||||
}
|
||||
|
||||
// FileContents returns the string contents of the given file.
|
||||
// TODO: we should support retrieving the content (or an approximation of)
|
||||
// non-text documents, e.g. PDFs.
|
||||
func FileContents(file string) (string, error) {
|
||||
// ReadFileContents returns the string contents of the given file.
|
||||
func ReadFileContents(file string) (string, error) {
|
||||
path := filepath.Clean(file)
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
Reference in New Issue
Block a user