diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index eb7ec2e9..8651d406 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -41,6 +41,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/refresh" "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/cmd/reset" + "github.com/brevdev/brev-cli/pkg/cmd/revokessh" "github.com/brevdev/brev-cli/pkg/cmd/runtasks" "github.com/brevdev/brev-cli/pkg/cmd/scale" "github.com/brevdev/brev-cli/pkg/cmd/set" @@ -333,6 +334,7 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(deregister.NewCmdDeregister(t, externalNodeCmdStore)) cmd.AddCommand(enablessh.NewCmdEnableSSH(t, externalNodeCmdStore)) cmd.AddCommand(grantssh.NewCmdGrantSSH(t, externalNodeCmdStore)) + cmd.AddCommand(revokessh.NewCmdRevokeSSH(t, externalNodeCmdStore)) cmd.AddCommand(runtasks.NewCmdRunTasks(t, noLoginCmdStore)) cmd.AddCommand(proxy.NewCmdProxy(t, noLoginCmdStore)) cmd.AddCommand(healthcheck.NewCmdHealthcheck(t, noLoginCmdStore)) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 06b902dd..7113a31d 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -8,11 +8,11 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" + breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" - breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/terminal" @@ -22,7 +22,6 @@ import ( // DeregisterStore defines the store methods needed by the deregister command. type DeregisterStore interface { GetCurrentUser() (*entity.User, error) - GetBrevHomePath() (string, error) GetAccessToken() (string, error) } @@ -53,13 +52,13 @@ type deregisterDeps struct { sshKeys SSHKeyRemover } -func defaultDeregisterDeps(brevHome string) deregisterDeps { +func defaultDeregisterDeps() deregisterDeps { return deregisterDeps{ platform: register.LinuxPlatform{}, prompter: register.TerminalPrompter{}, netbird: register.Netbird{}, nodeClients: register.DefaultNodeClientFactory{}, - registrationStore: register.NewFileRegistrationStore(brevHome), + registrationStore: register.NewFileRegistrationStore(), sshKeys: brevSSHKeyRemover{}, } } @@ -82,11 +81,7 @@ func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Comman Long: deregisterLong, Example: deregisterExample, RunE: func(cmd *cobra.Command, args []string) error { - brevHome, err := store.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return runDeregister(cmd.Context(), t, store, defaultDeregisterDeps(brevHome)) + return runDeregister(cmd.Context(), t, store, defaultDeregisterDeps()) }, } @@ -98,17 +93,9 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, return fmt.Errorf("brev deregister is only supported on Linux") } - registered, err := deps.registrationStore.Exists() - if err != nil { - return breverrors.WrapAndTrace(err) - } - if !registered { - return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device") - } - reg, err := deps.registrationStore.Load() if err != nil { - return fmt.Errorf("failed to read registration file: %w", err) + return breverrors.WrapAndTrace(err) } t.Vprint("") @@ -169,7 +156,7 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, t.Vprint("Removing registration data...") if err := deps.registrationStore.Delete(); err != nil { t.Vprintf(" Warning: failed to remove local registration file: %v\n", err) - t.Vprint(" You can manually remove it with: rm ~/.brev/device_registration.json") + t.Vprint(" You can manually remove it with: rm /etc/brev/device_registration.json") } t.Vprint(t.Green("Deregistration complete.")) diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index b901c9ff..64870429 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -19,7 +19,6 @@ import ( type mockDeregisterStore struct { user *entity.User - home string token string err error } @@ -31,8 +30,7 @@ func (m *mockDeregisterStore) GetCurrentUser() (*entity.User, error) { return m.user, nil } -func (m *mockDeregisterStore) GetBrevHomePath() (string, error) { return m.home, nil } -func (m *mockDeregisterStore) GetAccessToken() (string, error) { return m.token, nil } +func (m *mockDeregisterStore) GetAccessToken() (string, error) { return m.token, nil } // fakeNodeService implements the server side of ExternalNodeService for testing. type fakeNodeService struct { @@ -150,8 +148,8 @@ func Test_runDeregister_HappyPath(t *testing.T) { } store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } @@ -196,8 +194,8 @@ func Test_runDeregister_UserCancels(t *testing.T) { } store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } @@ -229,8 +227,8 @@ func Test_runDeregister_NotRegistered(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } @@ -255,8 +253,8 @@ func Test_runDeregister_RemoveNodeFails(t *testing.T) { } store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } @@ -295,8 +293,8 @@ func Test_runDeregister_AlwaysUninstallsNetbird(t *testing.T) { } store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } @@ -343,8 +341,8 @@ func Test_runDeregister_RemoveBrevKeysHandling(t *testing.T) { } store := &mockDeregisterStore{ - user: &entity.User{ID: "user_1"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + token: "tok", } diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go index dac16bc8..7b8c5f5e 100644 --- a/pkg/cmd/enablessh/enablessh.go +++ b/pkg/cmd/enablessh/enablessh.go @@ -20,7 +20,6 @@ import ( // EnableSSHStore defines the store methods needed by the enableSSH command. type EnableSSHStore interface { GetCurrentUser() (*entity.User, error) - GetBrevHomePath() (string, error) GetAccessToken() (string, error) } @@ -32,11 +31,11 @@ type enableSSHDeps struct { registrationStore register.RegistrationStore } -func defaultEnableSSHDeps(brevHome string) enableSSHDeps { +func defaultEnableSSHDeps() enableSSHDeps { return enableSSHDeps{ platform: register.LinuxPlatform{}, nodeClients: register.DefaultNodeClientFactory{}, - registrationStore: register.NewFileRegistrationStore(brevHome), + registrationStore: register.NewFileRegistrationStore(), } } @@ -49,11 +48,7 @@ func NewCmdEnableSSH(t *terminal.Terminal, store EnableSSHStore) *cobra.Command Long: "Enable SSH access to this registered device for the current Brev user.", Example: " brev enable-ssh", RunE: func(cmd *cobra.Command, args []string) error { - brevHome, err := store.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return runEnableSSH(cmd.Context(), t, store, defaultEnableSSHDeps(brevHome)) + return runEnableSSH(cmd.Context(), t, store, defaultEnableSSHDeps()) }, } @@ -65,14 +60,6 @@ func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, d return fmt.Errorf("brev enable-ssh is only supported on Linux") } - registered, err := deps.registrationStore.Exists() - if err != nil { - return breverrors.WrapAndTrace(err) - } - if !registered { - return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") - } - reg, err := deps.registrationStore.Load() if err != nil { return fmt.Errorf("failed to read registration file: %w", err) diff --git a/pkg/cmd/enablessh/enablessh_test.go b/pkg/cmd/enablessh/enablessh_test.go index 8c26b88f..f62fe90a 100644 --- a/pkg/cmd/enablessh/enablessh_test.go +++ b/pkg/cmd/enablessh/enablessh_test.go @@ -31,23 +31,23 @@ func readAuthorizedKeys(t *testing.T, u *user.User) string { func Test_InstallAuthorizedKey_TagsKeyWithBrevComment(t *testing.T) { u := tempUser(t) - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { - t.Errorf("expected key tagged with %q, got:\n%s", register.BrevKeyComment, content) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { + t.Errorf("expected key tagged with %q, got:\n%s", register.BrevKeyPrefix, content) } } func Test_InstallAuthorizedKey_SkipsDuplicate(t *testing.T) { u := tempUser(t) - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("first install: %v", err) } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("second install: %v", err) } @@ -66,11 +66,11 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { if err := os.MkdirAll(sshDir, 0o700); err != nil { t.Fatal(err) } - if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+register.BrevKeyComment+"\n"), 0o600); err != nil { + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte("ssh-rsa AAAA testkey "+register.BrevKeyPrefix+"\n"), 0o600); err != nil { t.Fatal(err) } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -84,10 +84,10 @@ func Test_InstallAuthorizedKey_SkipsDuplicateEvenIfAlreadyTagged(t *testing.T) { func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { u := tempUser(t) - if _, err := register.InstallAuthorizedKey(u, ""); err != nil { + if _, err := register.InstallAuthorizedKey(u, "", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } - if _, err := register.InstallAuthorizedKey(u, " "); err != nil { + if _, err := register.InstallAuthorizedKey(u, " ", ""); err != nil { t.Fatalf("InstallAuthorizedKey (whitespace): %v", err) } @@ -101,7 +101,7 @@ func Test_InstallAuthorizedKey_EmptyKeyIsNoop(t *testing.T) { func Test_InstallAuthorizedKey_CreatesSSHDir(t *testing.T) { u := tempUser(t) - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -126,7 +126,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { t.Fatal(err) } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } @@ -134,7 +134,7 @@ func Test_InstallAuthorizedKey_PreservesExistingKeys(t *testing.T) { if !strings.Contains(content, "ssh-rsa EXISTING user@host") { t.Errorf("existing key was lost:\n%s", content) } - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { t.Errorf("new key not found:\n%s", content) } } @@ -152,13 +152,13 @@ func Test_InstallAuthorizedKey_TagsExistingUntaggedKey(t *testing.T) { } // InstallAuthorizedKey should tag the existing key rather than adding a duplicate. - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { t.Fatalf("InstallAuthorizedKey: %v", err) } content := readAuthorizedKeys(t, u) - if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyComment) { - t.Errorf("expected existing key to be tagged with %q, got:\n%s", register.BrevKeyComment, content) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+register.BrevKeyPrefix) { + t.Errorf("expected existing key to be tagged with %q, got:\n%s", register.BrevKeyPrefix, content) } count := strings.Count(content, "ssh-rsa AAAA testkey") if count != 1 { @@ -177,9 +177,9 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { content := strings.Join([]string{ "ssh-rsa EXISTING user@host", - "ssh-rsa BREVKEY1 " + register.BrevKeyComment, + "ssh-rsa BREVKEY1 " + register.BrevKeyPrefix, "ssh-ed25519 OTHERKEY admin@server", - "ssh-rsa BREVKEY2 " + register.BrevKeyComment, + "ssh-rsa BREVKEY2 " + register.BrevKeyPrefix, "", }, "\n") if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { @@ -196,7 +196,7 @@ func Test_RemoveBrevAuthorizedKeys_RemovesTaggedKeys(t *testing.T) { } result := readAuthorizedKeys(t, u) - if strings.Contains(result, register.BrevKeyComment) { + if strings.Contains(result, register.BrevKeyPrefix) { t.Errorf("brev keys still present:\n%s", result) } if !strings.Contains(result, "ssh-rsa EXISTING user@host") { @@ -256,7 +256,7 @@ func Test_RemoveAuthorizedKey_RemovesOnlyTargetKey(t *testing.T) { content := strings.Join([]string{ "ssh-rsa KEEP1 user@host", - "ssh-rsa TARGET " + register.BrevKeyComment, + "ssh-rsa TARGET " + register.BrevKeyPrefix, "ssh-rsa KEEP2 admin@server", "", }, "\n") @@ -329,8 +329,8 @@ func Test_RemoveAuthorizedKey_DoesNotRemoveOtherBrevKeys(t *testing.T) { } content := strings.Join([]string{ - "ssh-rsa ALICE_KEY " + register.BrevKeyComment, - "ssh-rsa BOB_KEY " + register.BrevKeyComment, + "ssh-rsa ALICE_KEY " + register.BrevKeyPrefix, + "ssh-rsa BOB_KEY " + register.BrevKeyPrefix, "", }, "\n") if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { @@ -366,10 +366,10 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } // Install two brev keys. - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY1", "user_1"); err != nil { t.Fatal(err) } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa KEY2", "user_2"); err != nil { t.Fatal(err) } @@ -379,7 +379,7 @@ func Test_InstallThenRemove_RoundTrip(t *testing.T) { } result := readAuthorizedKeys(t, u) - if strings.Contains(result, register.BrevKeyComment) { + if strings.Contains(result, register.BrevKeyPrefix) { t.Errorf("brev keys still present after removal:\n%s", result) } if !strings.Contains(result, "ssh-rsa EXISTING user@host") { @@ -393,10 +393,10 @@ func Test_InstallThenRemoveSpecificKey_RollbackScenario(t *testing.T) { u := tempUser(t) // Install two brev keys (simulating two users granted access). - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa ALICE", "user_a"); err != nil { t.Fatal(err) } - if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB"); err != nil { + if _, err := register.InstallAuthorizedKey(u, "ssh-rsa BOB", "user_b"); err != nil { t.Fatal(err) } diff --git a/pkg/cmd/grantssh/grantssh.go b/pkg/cmd/grantssh/grantssh.go index 72543e19..d4c9b6b0 100644 --- a/pkg/cmd/grantssh/grantssh.go +++ b/pkg/cmd/grantssh/grantssh.go @@ -43,12 +43,12 @@ type resolvedMember struct { attachment entity.OrgRoleAttachment } -func defaultGrantSSHDeps(brevHome string) grantSSHDeps { +func defaultGrantSSHDeps() grantSSHDeps { return grantSSHDeps{ platform: register.LinuxPlatform{}, prompter: register.TerminalPrompter{}, nodeClients: register.DefaultNodeClientFactory{}, - registrationStore: register.NewFileRegistrationStore(brevHome), + registrationStore: register.NewFileRegistrationStore(), } } @@ -61,11 +61,7 @@ func NewCmdGrantSSH(t *terminal.Terminal, store GrantSSHStore) *cobra.Command { Long: "Grant SSH access to this registered device for another member of your organization.", Example: " brev grant-ssh", RunE: func(cmd *cobra.Command, args []string) error { - brevHome, err := store.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return runGrantSSH(cmd.Context(), t, store, defaultGrantSSHDeps(brevHome)) + return runGrantSSH(cmd.Context(), t, store, defaultGrantSSHDeps()) }, } @@ -79,7 +75,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep removeCredentialsFile(t, s) - reg, err := getRegistration(deps) + reg, err := deps.registrationStore.Load() if err != nil { return breverrors.WrapAndTrace(err) } @@ -148,7 +144,7 @@ func runGrantSSH(ctx context.Context, t *terminal.Terminal, s GrantSSHStore, dep func checkSSHEnabled(currentUserPubKey string) error { currentUserPubKey = strings.TrimSpace(currentUserPubKey) if currentUserPubKey == "" { - return fmt.Errorf("curren user does not have a Brev public key") + return fmt.Errorf("current user does not have a Brev public key") } u, err := user.Current() @@ -169,22 +165,6 @@ func checkSSHEnabled(currentUserPubKey string) error { return nil } -func getRegistration(deps grantSSHDeps) (*register.DeviceRegistration, error) { - registered, err := deps.registrationStore.Exists() - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - if !registered { - return nil, fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") - } - - reg, err := deps.registrationStore.Load() - if err != nil { - return nil, fmt.Errorf("failed to read registration file: %w", err) - } - return reg, nil -} - func getOrgMembers(currentUser *entity.User, t *terminal.Terminal, s GrantSSHStore, orgId string) ([]resolvedMember, error) { attachments, err := s.GetOrgRoleAttachments(orgId) if err != nil { diff --git a/pkg/cmd/register/device_registration_store.go b/pkg/cmd/register/device_registration_store.go index 0161ce76..102fe92c 100644 --- a/pkg/cmd/register/device_registration_store.go +++ b/pkg/cmd/register/device_registration_store.go @@ -1,8 +1,11 @@ package register import ( + "bytes" "encoding/json" + "fmt" "os" + "os/exec" "path/filepath" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -10,7 +13,10 @@ import ( "github.com/spf13/afero" ) -const registrationFileName = "device_registration.json" +const ( + registrationFileName = "device_registration.json" + globalRegistrationDir = "/etc/brev" +) // DeviceRegistration is the persistent identity file for a registered device. // Fields align with the AddNodeResponse from dev-plane. @@ -31,18 +37,17 @@ type RegistrationStore interface { Exists() (bool, error) } -// FileRegistrationStore implements RegistrationStore using the local filesystem. -type FileRegistrationStore struct { - brevHome string -} +// FileRegistrationStore implements RegistrationStore using the global /etc/brev/ path. +type FileRegistrationStore struct{} -// NewFileRegistrationStore returns a FileRegistrationStore rooted at brevHome. -func NewFileRegistrationStore(brevHome string) *FileRegistrationStore { - return &FileRegistrationStore{brevHome: brevHome} +// NewFileRegistrationStore returns a FileRegistrationStore that reads/writes +// from /etc/brev/device_registration.json. +func NewFileRegistrationStore() *FileRegistrationStore { + return &FileRegistrationStore{} } func (s *FileRegistrationStore) path() string { - return filepath.Join(s.brevHome, registrationFileName) + return filepath.Join(globalRegistrationDir, registrationFileName) } func (s *FileRegistrationStore) Save(reg *DeviceRegistration) error { @@ -51,20 +56,31 @@ func (s *FileRegistrationStore) Save(reg *DeviceRegistration) error { if err != nil { return breverrors.WrapAndTrace(err) } - if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o700); err != nil { - return breverrors.WrapAndTrace(err) - } - if err := afero.WriteFile(files.AppFs, path, data, 0o600); err != nil { - return breverrors.WrapAndTrace(err) + + // Try direct write first (works in tests with in-memory FS and when running as root). + mkdirErr := files.AppFs.MkdirAll(filepath.Dir(path), 0o755) + if mkdirErr == nil { + if writeErr := afero.WriteFile(files.AppFs, path, data, 0o644); writeErr == nil { + return nil + } } - return nil + + // Fall back to sudo for non-root users writing to /etc/brev/. + return sudoWriteFile(path, data) } +// Load reads the registration file and returns the parsed DeviceRegistration func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { path := s.path() + exists, err := s.Exists() + if !exists { + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return nil, breverrors.WrapAndTrace(breverrors.New("device registration not found, run 'brev register' first")) + } var reg DeviceRegistration - err := files.ReadJSON(files.AppFs, path, ®) - if err != nil { + if err := files.ReadJSON(files.AppFs, path, ®); err != nil { return nil, breverrors.WrapAndTrace(err) } return ®, nil @@ -73,10 +89,14 @@ func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { func (s *FileRegistrationStore) Delete() error { path := s.path() err := files.DeleteFile(files.AppFs, path) - if err != nil { + if err == nil { + return nil + } + if !os.IsPermission(err) { return breverrors.WrapAndTrace(err) } - return nil + // Fall back to sudo for non-root users. + return sudoDeleteFile(path) } func (s *FileRegistrationStore) Exists() (bool, error) { @@ -90,3 +110,26 @@ func (s *FileRegistrationStore) Exists() (bool, error) { } return false, breverrors.WrapAndTrace(err) } + +// sudoWriteFile creates the parent directory and writes data to path using sudo. +func sudoWriteFile(path string, data []byte) error { + dir := filepath.Dir(path) + if err := exec.Command("sudo", "mkdir", "-p", dir).Run(); err != nil { //nolint:gosec // fixed base path + return fmt.Errorf("sudo mkdir %s failed: %w", dir, err) + } + cmd := exec.Command("sudo", "tee", path) //nolint:gosec // fixed base path + cmd.Stdin = bytes.NewReader(data) + cmd.Stdout = nil // suppress tee's stdout echo + if err := cmd.Run(); err != nil { + return fmt.Errorf("sudo tee %s failed: %w", path, err) + } + return nil +} + +// sudoDeleteFile removes a file using sudo. +func sudoDeleteFile(path string) error { + if err := exec.Command("sudo", "rm", "-f", path).Run(); err != nil { //nolint:gosec // fixed base path + return fmt.Errorf("sudo rm %s failed: %w", path, err) + } + return nil +} diff --git a/pkg/cmd/register/device_registration_store_test.go b/pkg/cmd/register/device_registration_store_test.go index 6e14514d..256acd4a 100644 --- a/pkg/cmd/register/device_registration_store_test.go +++ b/pkg/cmd/register/device_registration_store_test.go @@ -7,22 +7,21 @@ import ( "github.com/spf13/afero" ) -func setupTestFs(t *testing.T) (string, func()) { +func setupTestFs(t *testing.T) func() { t.Helper() origFs := files.AppFs files.AppFs = afero.NewMemMapFs() - brevHome := "/home/testuser/.brev" - if err := files.AppFs.MkdirAll(brevHome, 0o770); err != nil { + if err := files.AppFs.MkdirAll(globalRegistrationDir, 0o755); err != nil { t.Fatalf("failed to create test dir: %v", err) } - return brevHome, func() { files.AppFs = origFs } + return func() { files.AppFs = origFs } } func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() cpuCount := int32(12) ramBytes := int64(137438953472) @@ -69,10 +68,10 @@ func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { } func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() exists, err := store.Exists() if err != nil { @@ -84,10 +83,10 @@ func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { } func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() reg := &DeviceRegistration{ ExternalNodeID: "unode_abc123", @@ -107,10 +106,10 @@ func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { } func Test_DeleteRegistration_RemovesFile(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() reg := &DeviceRegistration{ ExternalNodeID: "unode_abc123", @@ -134,10 +133,10 @@ func Test_DeleteRegistration_RemovesFile(t *testing.T) { } func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() _, err := store.Load() if err == nil { @@ -146,10 +145,10 @@ func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { } func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { - brevHome, cleanup := setupTestFs(t) + cleanup := setupTestFs(t) defer cleanup() - store := NewFileRegistrationStore(brevHome) + store := NewFileRegistrationStore() err := store.Delete() if err == nil { diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go index 866231bf..d23ec170 100644 --- a/pkg/cmd/register/hardware.go +++ b/pkg/cmd/register/hardware.go @@ -167,11 +167,7 @@ func parseOSReleaseContent(content string) (string, string) { // unquote removes surrounding double quotes from a string. func unquote(s string) string { - s = strings.TrimSpace(s) - if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { - return s[1 : len(s)-1] - } - return s + return strings.Trim(strings.TrimSpace(s), "\"") } // parseNvidiaSMI queries nvidia-smi for GPU information. diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 5049ff3f..63cc9550 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -27,7 +27,6 @@ import ( type RegisterStore interface { GetCurrentUser() (*entity.User, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) - GetBrevHomePath() (string, error) GetAccessToken() (string, error) } @@ -66,7 +65,7 @@ type registerDeps struct { registrationStore RegistrationStore } -func defaultRegisterDeps(brevHome string) registerDeps { +func defaultRegisterDeps() registerDeps { return registerDeps{ platform: LinuxPlatform{}, prompter: TerminalPrompter{}, @@ -75,7 +74,7 @@ func defaultRegisterDeps(brevHome string) registerDeps { nodeClients: DefaultNodeClientFactory{}, commandRunner: ExecCommandRunner{}, fileReader: OSFileReader{}, - registrationStore: NewFileRegistrationStore(brevHome), + registrationStore: NewFileRegistrationStore(), } } @@ -90,18 +89,18 @@ This command sets up network connectivity and registers this machine with Brev.` func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, - Use: "register ", + Use: "register [name]", DisableFlagsInUseLine: true, Short: "Register this device with Brev", Long: registerLong, Example: registerExample, - Args: cobra.ExactArgs(1), + Args: cobra.MaximumNArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - brevHome, err := store.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) + var name string + if len(args) > 0 { + name = args[0] } - return runRegister(cmd.Context(), t, store, args[0], defaultRegisterDeps(brevHome)) + return runRegister(cmd.Context(), t, store, name, defaultRegisterDeps()) }, } @@ -118,11 +117,11 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return breverrors.WrapAndTrace(err) } if alreadyRegistered { - reg, loadErr := deps.registrationStore.Load() - if loadErr != nil { - return fmt.Errorf("this machine is already registered but the registration file could not be read: %w", loadErr) - } - return checkExistingRegistration(ctx, t, s, deps, reg) + return checkExistingRegistration(ctx, t, s, name, deps) + } + + if name == "" { + return fmt.Errorf("please provide a name for this device\n\nUsage: brev register \nExample: brev register \"my-DGX-Spark\"") } brevUser, err := s.GetCurrentUser() @@ -232,7 +231,20 @@ func getOrgToRegisterFor(s RegisterStore) (*entity.Organization, error) { // It calls GetNode to check the server-side NetworkMemberStatus and ensures the // local netbird service is running, starting it if necessary. Returns nil if // the node is healthy, or an error describing what's wrong. -func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s RegisterStore, deps registerDeps, reg *DeviceRegistration) error { +func checkExistingRegistration(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { + reg, loadErr := deps.registrationStore.Load() + if loadErr != nil { + return fmt.Errorf("this machine is already registered but the registration file could not be read: %w", loadErr) + } + if name != "" && name != reg.DisplayName { + // TODO maybe allow for a name change + t.Vprintf("This machine is already registered as %q.\n", reg.DisplayName) + t.Vprint("Run 'brev deregister' first if you want to re-register with a different name.") + t.Vprint("") + t.Vprintf("If you are having tunnel issues, run 'brev register %q' to reconnect.", reg.DisplayName) + return nil + } + t.Vprint("") t.Vprintf(" This machine is already registered as %s (%s).\n", reg.DisplayName, reg.ExternalNodeID) t.Vprint(" Checking connectivity...") diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 450ae720..8eb25776 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -20,7 +20,6 @@ import ( type mockRegisterStore struct { user *entity.User org *entity.Organization - home string token string err error } @@ -36,8 +35,7 @@ func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organizati return m.org, nil } -func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } -func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } +func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } // mockRegistrationStore satisfies RegistrationStore for orchestration tests. type mockRegistrationStore struct { @@ -135,9 +133,9 @@ func Test_runRegister_HappyPath(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", } @@ -209,9 +207,9 @@ func Test_runRegister_UserCancels(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", } @@ -297,9 +295,9 @@ func Test_runRegister_AlreadyRegistered(t *testing.T) { } store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", } @@ -308,7 +306,9 @@ func Test_runRegister_AlreadyRegistered(t *testing.T) { defer server.Close() term := terminal.New() - err := runRegister(context.Background(), term, store, "My Spark", deps) + // Pass the same name as the existing registration so we go through + // the checkExistingRegistration path (not the different-name path). + err := runRegister(context.Background(), term, store, "Existing", deps) if err != nil { t.Fatalf("expected nil error, got: %v", err) } @@ -325,9 +325,9 @@ func Test_runRegister_NoOrganization(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: nil, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: nil, + token: "tok", } @@ -346,9 +346,9 @@ func Test_runRegister_AddNodeFails(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", } @@ -381,9 +381,9 @@ func Test_runRegister_NoSetupCommand(t *testing.T) { regStore := &mockRegistrationStore{} store := &mockRegisterStore{ - user: &entity.User{ID: "user_1"}, - org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", } @@ -506,7 +506,6 @@ func Test_runRegister_GrantSSH_retries_on_connection_error_then_succeeds(t *test store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", token: "tok", } @@ -556,7 +555,6 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: "/home/testuser/.brev", token: "tok", } @@ -596,3 +594,70 @@ func Test_runRegister_GrantSSH_no_retry_on_permanent_error(t *testing.T) { t.Errorf("expected GrantNodeSSHAccess to be called once (no retry on permanent error), got %d", grantCalls) } } + +func Test_runRegister_NoNameNotRegistered(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "", deps) + if err == nil { + t.Fatal("expected error when no name provided and not registered") + } + if !strings.Contains(err.Error(), "please provide a name") { + t.Errorf("expected 'please provide a name' error, got: %v", err) + } +} + +func Test_runRegister_NoNameAlreadyRegistered(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing Device", + OrgID: "org_123", + }, + } + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + token: "tok", + } + + svc := &fakeNodeService{ + getNodeFn: func(req *nodev1.GetNodeRequest) (*nodev1.GetNodeResponse, error) { + return &nodev1.GetNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: req.GetExternalNodeId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + Status: nodev1.NetworkMemberStatus_NETWORK_MEMBER_STATUS_CONNECTED, + }, + }, + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "", deps) + if err != nil { + t.Fatalf("expected nil error when already registered with no name, got: %v", err) + } + + // Registration should still exist + exists, _ := regStore.Exists() + if !exists { + t.Error("expected registration to still exist") + } +} diff --git a/pkg/cmd/register/sshkeys.go b/pkg/cmd/register/sshkeys.go index aad47e95..4414cad9 100644 --- a/pkg/cmd/register/sshkeys.go +++ b/pkg/cmd/register/sshkeys.go @@ -28,9 +28,107 @@ const ( backoffPrintRound = 500 * time.Millisecond ) -// BrevKeyComment is the marker appended to every SSH key that Brev installs. -// It allows RemoveBrevAuthorizedKeys to identify and remove exactly those keys. -const BrevKeyComment = "# brev-cli" +// BrevKeyPrefix is the marker prefix appended to every SSH key that Brev +// installs. Both old-format ("# brev-cli") and new-format +// ("# brev-cli user_id=xxx") keys start with this prefix. +const BrevKeyPrefix = "# brev-cli" + +// BrevKeyTag returns a comment tag that encodes the Brev user ID. +// Example: "# brev-cli user_id=user_abc123" +func BrevKeyTag(userID string) string { + if userID == "" { + return BrevKeyPrefix + } + return fmt.Sprintf("%s user_id=%s", BrevKeyPrefix, userID) +} + +// BrevAuthorizedKey represents a single Brev-managed key found in +// authorized_keys. +type BrevAuthorizedKey struct { + Line string // full line from authorized_keys + KeyContent string // the ssh key portion (without the brev comment) + UserID string // parsed from "user_id=xxx", empty for old-format keys +} + +// ListBrevAuthorizedKeys reads ~/.ssh/authorized_keys and returns all lines +// containing the BrevKeyPrefix marker. +func ListBrevAuthorizedKeys(u *user.User) ([]BrevAuthorizedKey, error) { + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + + data, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("reading authorized_keys: %w", err) + } + + var keys []BrevAuthorizedKey + for _, line := range strings.Split(string(data), "\n") { + if !strings.Contains(line, BrevKeyPrefix) { + continue + } + trimmed := strings.TrimSpace(line) + if trimmed == "" { + continue + } + + bk := BrevAuthorizedKey{Line: trimmed} + + // Split on " # brev-cli" to get the key content before the tag. + if idx := strings.Index(trimmed, " "+BrevKeyPrefix); idx >= 0 { + bk.KeyContent = trimmed[:idx] + tag := trimmed[idx+1:] // the "# brev-cli ..." part + // Parse user_id if present. + if uidIdx := strings.Index(tag, "user_id="); uidIdx >= 0 { + rest := tag[uidIdx+len("user_id="):] + // user_id value ends at next space or end of string. + if spIdx := strings.Index(rest, " "); spIdx >= 0 { + bk.UserID = rest[:spIdx] + } else { + bk.UserID = rest + } + } + } else { + bk.KeyContent = trimmed + } + + keys = append(keys, bk) + } + + return keys, nil +} + +// RemoveAuthorizedKeyLine removes an exact line from authorized_keys. +func RemoveAuthorizedKeyLine(u *user.User, line string) error { + line = strings.TrimSpace(line) + if line == "" { + return nil + } + + authKeysPath := filepath.Join(u.HomeDir, ".ssh", "authorized_keys") + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("reading authorized_keys: %w", err) + } + + var kept []string + for _, l := range strings.Split(string(existing), "\n") { + if strings.TrimSpace(l) == line { + continue + } + kept = append(kept, l) + } + + result := strings.Join(kept, "\n") + if err := os.WriteFile(authKeysPath, []byte(result), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + return nil +} // GrantSSHAccessToNode installs the user's public key in authorized_keys and // calls GrantNodeSSHAccess to record access server-side. If the RPC fails, @@ -45,7 +143,7 @@ func GrantSSHAccessToNode( osUser *user.User, ) error { if targetUser.PublicKey != "" { - if added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey); err != nil { + if added, err := InstallAuthorizedKey(osUser, targetUser.PublicKey, targetUser.ID); err != nil { t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) } else if added { t.Vprint(" Brev public key added to authorized_keys.") @@ -96,9 +194,10 @@ func GrantSSHAccessToNode( // InstallAuthorizedKey appends the given public key to the user's // ~/.ssh/authorized_keys if it isn't already present. The key is tagged with -// a brev-cli comment so it can be removed later by RemoveBrevAuthorizedKeys. +// a brev-cli comment (including the user ID) so it can be identified and +// removed later by RemoveBrevAuthorizedKeys or ListBrevAuthorizedKeys. // Returns true if the key was newly written, false if it was already present. -func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { +func InstallAuthorizedKey(u *user.User, pubKey string, brevUserID string) (bool, error) { pubKey = strings.TrimSpace(pubKey) if pubKey == "" { return false, nil @@ -116,7 +215,7 @@ func InstallAuthorizedKey(u *user.User, pubKey string) (bool, error) { return false, fmt.Errorf("reading authorized_keys: %w", err) } - taggedKey := pubKey + " " + BrevKeyComment + taggedKey := pubKey + " " + BrevKeyTag(brevUserID) if strings.Contains(string(existing), taggedKey) { return false, nil // already present with tag @@ -197,7 +296,7 @@ func RemoveBrevAuthorizedKeys(u *user.User) ([]string, error) { var kept []string var removed []string for _, line := range strings.Split(string(existing), "\n") { - if strings.Contains(line, BrevKeyComment) { + if strings.Contains(line, BrevKeyPrefix) { if trimmed := strings.TrimSpace(line); trimmed != "" { removed = append(removed, trimmed) } diff --git a/pkg/cmd/register/sshkeys_test.go b/pkg/cmd/register/sshkeys_test.go new file mode 100644 index 00000000..4537d847 --- /dev/null +++ b/pkg/cmd/register/sshkeys_test.go @@ -0,0 +1,232 @@ +package register + +import ( + "os" + "os/user" + "path/filepath" + "strings" + "testing" +) + +func tempUser(t *testing.T) *user.User { + t.Helper() + return &user.User{HomeDir: t.TempDir()} +} + +func seedKeys(t *testing.T, u *user.User, content string) { + t.Helper() + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } +} + +func readKeys(t *testing.T, u *user.User) string { + t.Helper() + data, err := os.ReadFile(filepath.Join(u.HomeDir, ".ssh", "authorized_keys")) + if err != nil { + t.Fatalf("reading authorized_keys: %v", err) + } + return string(data) +} + +// --- BrevKeyTag --- + +func TestBrevKeyTag_WithUserID(t *testing.T) { + tag := BrevKeyTag("user_abc123") + expected := "# brev-cli user_id=user_abc123" + if tag != expected { + t.Errorf("expected %q, got %q", expected, tag) + } +} + +func TestBrevKeyTag_EmptyUserID(t *testing.T) { + tag := BrevKeyTag("") + if tag != BrevKeyPrefix { + t.Errorf("expected %q, got %q", BrevKeyPrefix, tag) + } +} + +// --- ListBrevAuthorizedKeys --- + +func TestListBrevAuthorizedKeys_ParsesNewFormat(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, strings.Join([]string{ + "ssh-rsa EXISTING user@host", + "ssh-ed25519 AAAA_ALICE # brev-cli user_id=user_1", + "ssh-rsa AAAA_BOB # brev-cli user_id=user_2", + "", + }, "\n")) + + keys, err := ListBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("ListBrevAuthorizedKeys: %v", err) + } + + if len(keys) != 2 { + t.Fatalf("expected 2 keys, got %d", len(keys)) + } + + if keys[0].KeyContent != "ssh-ed25519 AAAA_ALICE" { + t.Errorf("expected key content 'ssh-ed25519 AAAA_ALICE', got %q", keys[0].KeyContent) + } + if keys[0].UserID != "user_1" { + t.Errorf("expected user_id 'user_1', got %q", keys[0].UserID) + } + + if keys[1].KeyContent != "ssh-rsa AAAA_BOB" { + t.Errorf("expected key content 'ssh-rsa AAAA_BOB', got %q", keys[1].KeyContent) + } + if keys[1].UserID != "user_2" { + t.Errorf("expected user_id 'user_2', got %q", keys[1].UserID) + } +} + +func TestListBrevAuthorizedKeys_ParsesOldFormat(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, "ssh-ed25519 AAAA_OLD # brev-cli\n") + + keys, err := ListBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("ListBrevAuthorizedKeys: %v", err) + } + + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].KeyContent != "ssh-ed25519 AAAA_OLD" { + t.Errorf("expected key content 'ssh-ed25519 AAAA_OLD', got %q", keys[0].KeyContent) + } + if keys[0].UserID != "" { + t.Errorf("expected empty user_id for old format, got %q", keys[0].UserID) + } +} + +func TestListBrevAuthorizedKeys_MixedFormats(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, strings.Join([]string{ + "ssh-rsa AAAA_OLD # brev-cli", + "ssh-rsa NONBREV user@host", + "ssh-ed25519 AAAA_NEW # brev-cli user_id=uid_42", + "", + }, "\n")) + + keys, err := ListBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("ListBrevAuthorizedKeys: %v", err) + } + + if len(keys) != 2 { + t.Fatalf("expected 2 brev keys, got %d", len(keys)) + } + + // Old format + if keys[0].UserID != "" { + t.Errorf("expected empty user_id for old format, got %q", keys[0].UserID) + } + // New format + if keys[1].UserID != "uid_42" { + t.Errorf("expected user_id 'uid_42', got %q", keys[1].UserID) + } +} + +func TestListBrevAuthorizedKeys_NoFile(t *testing.T) { + u := tempUser(t) + + keys, err := ListBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("expected no error for missing file, got: %v", err) + } + if len(keys) != 0 { + t.Errorf("expected 0 keys, got %d", len(keys)) + } +} + +func TestListBrevAuthorizedKeys_NoBrevKeys(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, "ssh-rsa NONBREV user@host\n") + + keys, err := ListBrevAuthorizedKeys(u) + if err != nil { + t.Fatalf("ListBrevAuthorizedKeys: %v", err) + } + if len(keys) != 0 { + t.Errorf("expected 0 brev keys, got %d", len(keys)) + } +} + +// --- RemoveAuthorizedKeyLine --- + +func TestRemoveAuthorizedKeyLine_RemovesExactLine(t *testing.T) { + u := tempUser(t) + seedKeys(t, u, strings.Join([]string{ + "ssh-rsa KEEP user@host", + "ssh-ed25519 REMOVE # brev-cli user_id=user_1", + "ssh-rsa KEEP2 admin@server", + "", + }, "\n")) + + if err := RemoveAuthorizedKeyLine(u, "ssh-ed25519 REMOVE # brev-cli user_id=user_1"); err != nil { + t.Fatalf("RemoveAuthorizedKeyLine: %v", err) + } + + result := readKeys(t, u) + if strings.Contains(result, "REMOVE") { + t.Errorf("line was not removed:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa KEEP user@host") { + t.Errorf("other key was removed:\n%s", result) + } + if !strings.Contains(result, "ssh-rsa KEEP2 admin@server") { + t.Errorf("other key was removed:\n%s", result) + } +} + +func TestRemoveAuthorizedKeyLine_NoopForEmptyLine(t *testing.T) { + u := tempUser(t) + if err := RemoveAuthorizedKeyLine(u, ""); err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} + +func TestRemoveAuthorizedKeyLine_NoopForMissingFile(t *testing.T) { + u := tempUser(t) + if err := RemoveAuthorizedKeyLine(u, "ssh-rsa SOMETHING"); err != nil { + t.Fatalf("expected no error, got: %v", err) + } +} + +// --- InstallAuthorizedKey with user ID --- + +func TestInstallAuthorizedKey_IncludesUserID(t *testing.T) { + u := tempUser(t) + + if _, err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", "user_abc"); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readKeys(t, u) + expected := "ssh-rsa AAAA testkey # brev-cli user_id=user_abc" + if !strings.Contains(content, expected) { + t.Errorf("expected %q in authorized_keys, got:\n%s", expected, content) + } +} + +func TestInstallAuthorizedKey_EmptyUserID_UsesPrefix(t *testing.T) { + u := tempUser(t) + + if _, err := InstallAuthorizedKey(u, "ssh-rsa AAAA testkey", ""); err != nil { + t.Fatalf("InstallAuthorizedKey: %v", err) + } + + content := readKeys(t, u) + if !strings.Contains(content, "ssh-rsa AAAA testkey "+BrevKeyPrefix) { + t.Errorf("expected key tagged with prefix, got:\n%s", content) + } + if strings.Contains(content, "user_id=") { + t.Errorf("should not contain user_id when empty, got:\n%s", content) + } +} diff --git a/pkg/cmd/revokessh/revokessh.go b/pkg/cmd/revokessh/revokessh.go new file mode 100644 index 00000000..881a0d5f --- /dev/null +++ b/pkg/cmd/revokessh/revokessh.go @@ -0,0 +1,162 @@ +// Package revokessh provides the brev revoke-ssh command for revoking SSH access +// to a registered device by removing a Brev-managed key from authorized_keys. +package revokessh + +import ( + "context" + "fmt" + "os/user" + + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/externalnode" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// RevokeSSHStore defines the store methods needed by the revoke-ssh command. +type RevokeSSHStore interface { + GetAccessToken() (string, error) +} + +// revokeSSHDeps bundles the side-effecting dependencies of runRevokeSSH so they +// can be replaced in tests. +type revokeSSHDeps struct { + platform externalnode.PlatformChecker + prompter terminal.Selector + nodeClients externalnode.NodeClientFactory + registrationStore register.RegistrationStore + currentUser func() (*user.User, error) + listBrevKeys func(u *user.User) ([]register.BrevAuthorizedKey, error) + removeKeyLine func(u *user.User, line string) error +} + +func defaultRevokeSSHDeps() revokeSSHDeps { + return revokeSSHDeps{ + platform: register.LinuxPlatform{}, + prompter: register.TerminalPrompter{}, + nodeClients: register.DefaultNodeClientFactory{}, + registrationStore: register.NewFileRegistrationStore(), + currentUser: user.Current, + listBrevKeys: register.ListBrevAuthorizedKeys, + removeKeyLine: register.RemoveAuthorizedKeyLine, + } +} + +func NewCmdRevokeSSH(t *terminal.Terminal, store RevokeSSHStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "revoke-ssh", + DisableFlagsInUseLine: true, + Short: "Revoke SSH access to this device for an org member", + Long: "Revoke SSH access to this registered device for another member of your organization.", + Example: " brev revoke-ssh", + RunE: func(cmd *cobra.Command, args []string) error { + return runRevokeSSH(cmd.Context(), t, store, defaultRevokeSSHDeps()) + }, + } + + return cmd +} + +func runRevokeSSH(ctx context.Context, t *terminal.Terminal, s RevokeSSHStore, deps revokeSSHDeps) error { + if !deps.platform.IsCompatible() { + return fmt.Errorf("brev revoke-ssh is only supported on Linux") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + osUser, err := deps.currentUser() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + + brevKeys, err := deps.listBrevKeys(osUser) + if err != nil { + return fmt.Errorf("failed to read authorized keys: %w", err) + } + + if len(brevKeys) == 0 { + t.Vprint("No Brev SSH keys found to revoke.") + return nil + } + + // Build selector labels from the installed keys. + labels := make([]string, len(brevKeys)) + for i, bk := range brevKeys { + keyPreview := truncateKey(bk.KeyContent, 60) + if bk.UserID != "" { + labels[i] = fmt.Sprintf("%s (user: %s)", keyPreview, bk.UserID) + } else { + labels[i] = keyPreview + } + } + + selected := deps.prompter.Select("Select a Brev SSH key to revoke:", labels) + + selectedIdx := -1 + for i, label := range labels { + if label == selected { + selectedIdx = i + break + } + } + if selectedIdx < 0 { + return fmt.Errorf("selected item %q did not match any key", selected) + } + + selectedKey := brevKeys[selectedIdx] + + t.Vprint("") + t.Vprint(t.Green("Revoking SSH access")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Key: %s\n", truncateKey(selectedKey.KeyContent, 80)) + if selectedKey.UserID != "" { + t.Vprintf(" User: %s\n", selectedKey.UserID) + } + t.Vprint("") + + // Remove the key from authorized_keys first. + if err := deps.removeKeyLine(osUser, selectedKey.Line); err != nil { + return fmt.Errorf("failed to remove key from authorized_keys: %w", err) + } + t.Vprint(" Brev public key removed from authorized_keys.") + + // If we know the user ID, also revoke server-side. + if selectedKey.UserID != "" { + client := deps.nodeClients.NewNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + _, err := client.RevokeNodeSSHAccess(ctx, connect.NewRequest(&nodev1.RevokeNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: selectedKey.UserID, + LinuxUser: osUser.Username, + })) + if err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: server-side revocation failed: %v", err))) + t.Vprint(" The key was removed locally but the server may still show access.") + } + } else { + t.Vprint(" Key was old-format (no user ID); skipping server-side revocation.") + } + + t.Vprint("") + t.Vprint(t.Green("SSH key revoked.")) + return nil +} + +// truncateKey shortens a key string for display, showing the first maxLen +// characters followed by "..." if it's longer. +func truncateKey(key string, maxLen int) string { + if len(key) <= maxLen { + return key + } + return key[:maxLen] + "..." +} diff --git a/pkg/cmd/revokessh/revokessh_test.go b/pkg/cmd/revokessh/revokessh_test.go new file mode 100644 index 00000000..9a0c5b19 --- /dev/null +++ b/pkg/cmd/revokessh/revokessh_test.go @@ -0,0 +1,353 @@ +package revokessh + +import ( + "context" + "fmt" + "net/http/httptest" + "os" + "os/user" + "path/filepath" + "strings" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/externalnode" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// --- mock types --- + +type mockPlatform struct{ compatible bool } + +func (m mockPlatform) IsCompatible() bool { return m.compatible } + +type mockSelector struct { + fn func(label string, items []string) string +} + +func (m mockSelector) Select(label string, items []string) string { + return m.fn(label, items) +} + +type mockNodeClientFactory struct { + serverURL string +} + +func (m mockNodeClientFactory) NewNodeClient(provider externalnode.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return register.NewNodeServiceClient(provider, m.serverURL) +} + +type mockRegistrationStore struct { + reg *register.DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *register.DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*register.DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +type mockRevokeSSHStore struct { + token string +} + +func (m *mockRevokeSSHStore) GetAccessToken() (string, error) { return m.token, nil } + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + revokeSSHFn func(*nodev1.RevokeNodeSSHAccessRequest) (*nodev1.RevokeNodeSSHAccessResponse, error) +} + +func (f *fakeNodeService) RevokeNodeSSHAccess(_ context.Context, req *connect.Request[nodev1.RevokeNodeSSHAccessRequest]) (*connect.Response[nodev1.RevokeNodeSSHAccessResponse], error) { + resp, err := f.revokeSSHFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +// --- helpers --- + +func tempUser(t *testing.T) *user.User { + t.Helper() + return &user.User{HomeDir: t.TempDir(), Username: "testuser"} +} + +func seedAuthorizedKeys(t *testing.T, u *user.User, content string) { + t.Helper() + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(sshDir, "authorized_keys"), []byte(content), 0o600); err != nil { + t.Fatal(err) + } +} + +func readAuthorizedKeys(t *testing.T, u *user.User) string { + t.Helper() + data, err := os.ReadFile(filepath.Join(u.HomeDir, ".ssh", "authorized_keys")) + if err != nil { + t.Fatalf("reading authorized_keys: %v", err) + } + return string(data) +} + +func baseDeps(t *testing.T, svc *fakeNodeService, regStore register.RegistrationStore, osUser *user.User) (revokeSSHDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return revokeSSHDeps{ + platform: mockPlatform{compatible: true}, + prompter: mockSelector{fn: func(_ string, items []string) string { + if len(items) > 0 { + return items[0] + } + return "" + }}, + nodeClients: mockNodeClientFactory{serverURL: server.URL}, + registrationStore: regStore, + currentUser: func() (*user.User, error) { return osUser, nil }, + listBrevKeys: register.ListBrevAuthorizedKeys, + removeKeyLine: register.RemoveAuthorizedKeyLine, + }, server +} + +// --- tests --- + +func Test_runRevokeSSH_NotCompatible(t *testing.T) { + osUser := tempUser(t) + regStore := &mockRegistrationStore{} + store := &mockRevokeSSHStore{token: "tok"} + svc := &fakeNodeService{} + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + deps.platform = mockPlatform{compatible: false} + + term := terminal.New() + err := runRevokeSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error for incompatible platform") + } +} + +func Test_runRevokeSSH_NotRegistered(t *testing.T) { + osUser := tempUser(t) + regStore := &mockRegistrationStore{} // no registration + store := &mockRevokeSSHStore{token: "tok"} + svc := &fakeNodeService{} + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + + term := terminal.New() + err := runRevokeSSH(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when not registered") + } +} + +func Test_runRevokeSSH_NoBrevKeys(t *testing.T) { + osUser := tempUser(t) + // Seed only non-brev keys. + seedAuthorizedKeys(t, osUser, "ssh-rsa EXISTING user@host\n") + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + }, + } + store := &mockRevokeSSHStore{token: "tok"} + svc := &fakeNodeService{} + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + + term := terminal.New() + err := runRevokeSSH(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected nil error for no brev keys, got: %v", err) + } +} + +func Test_runRevokeSSH_RevokeKeyWithUserID(t *testing.T) { + osUser := tempUser(t) + + keyLine := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIAliceKey # brev-cli user_id=user_2" + seedAuthorizedKeys(t, osUser, strings.Join([]string{ + "ssh-rsa EXISTING user@host", + keyLine, + "", + }, "\n")) + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + }, + } + store := &mockRevokeSSHStore{token: "tok"} + + var gotReq *nodev1.RevokeNodeSSHAccessRequest + svc := &fakeNodeService{ + revokeSSHFn: func(req *nodev1.RevokeNodeSSHAccessRequest) (*nodev1.RevokeNodeSSHAccessResponse, error) { + gotReq = req + return &nodev1.RevokeNodeSSHAccessResponse{}, nil + }, + } + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + + term := terminal.New() + err := runRevokeSSH(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runRevokeSSH failed: %v", err) + } + + // Verify key was removed from file. + remaining := readAuthorizedKeys(t, osUser) + if strings.Contains(remaining, "AliceKey") { + t.Errorf("expected key to be removed, still present:\n%s", remaining) + } + if !strings.Contains(remaining, "ssh-rsa EXISTING user@host") { + t.Errorf("non-brev key was removed:\n%s", remaining) + } + + // Verify RPC was called with correct params. + if gotReq == nil { + t.Fatal("expected RevokeNodeSSHAccess to be called") + } + if gotReq.GetExternalNodeId() != "unode_abc" { + t.Errorf("expected node ID unode_abc, got %s", gotReq.GetExternalNodeId()) + } + if gotReq.GetUserId() != "user_2" { + t.Errorf("expected user ID user_2, got %s", gotReq.GetUserId()) + } + if gotReq.GetLinuxUser() != "testuser" { + t.Errorf("expected linux user testuser, got %s", gotReq.GetLinuxUser()) + } +} + +func Test_runRevokeSSH_RevokeOldFormatKey_SkipsRPC(t *testing.T) { + osUser := tempUser(t) + + // Old-format key: no user_id in the tag. + keyLine := "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIBobKey # brev-cli" + seedAuthorizedKeys(t, osUser, strings.Join([]string{ + keyLine, + "", + }, "\n")) + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + }, + } + store := &mockRevokeSSHStore{token: "tok"} + + rpcCalled := false + svc := &fakeNodeService{ + revokeSSHFn: func(_ *nodev1.RevokeNodeSSHAccessRequest) (*nodev1.RevokeNodeSSHAccessResponse, error) { + rpcCalled = true + return &nodev1.RevokeNodeSSHAccessResponse{}, nil + }, + } + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + + term := terminal.New() + err := runRevokeSSH(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runRevokeSSH failed: %v", err) + } + + // Key should be removed from file. + remaining := readAuthorizedKeys(t, osUser) + if strings.Contains(remaining, "BobKey") { + t.Errorf("expected key to be removed, still present:\n%s", remaining) + } + + // RPC should NOT have been called. + if rpcCalled { + t.Error("expected RPC to be skipped for old-format key") + } +} + +func Test_runRevokeSSH_RPCFailure_StillRemovesKey(t *testing.T) { + osUser := tempUser(t) + + keyLine := "ssh-ed25519 AAAA # brev-cli user_id=user_3" + seedAuthorizedKeys(t, osUser, keyLine+"\n") + + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + }, + } + store := &mockRevokeSSHStore{token: "tok"} + + svc := &fakeNodeService{ + revokeSSHFn: func(_ *nodev1.RevokeNodeSSHAccessRequest) (*nodev1.RevokeNodeSSHAccessResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := baseDeps(t, svc, regStore, osUser) + defer server.Close() + + term := terminal.New() + // The command should succeed (key is removed) even though the RPC fails. + err := runRevokeSSH(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected no error (RPC failure is a warning), got: %v", err) + } + + remaining := readAuthorizedKeys(t, osUser) + if strings.Contains(remaining, "brev-cli") { + t.Errorf("key should be removed even when RPC fails:\n%s", remaining) + } +} + +func Test_truncateKey(t *testing.T) { + short := "ssh-rsa AAAA" + if got := truncateKey(short, 60); got != short { + t.Errorf("expected %q, got %q", short, got) + } + + long := strings.Repeat("x", 100) + got := truncateKey(long, 60) + if len(got) != 63 { // 60 + "..." + t.Errorf("expected length 63, got %d: %q", len(got), got) + } + if !strings.HasSuffix(got, "...") { + t.Errorf("expected suffix '...', got %q", got) + } +}