Files
go-storage/sftp.go

351 lines
8.5 KiB
Go

package go_storage
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"path/filepath"
"strings"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
// SFTPStorage implements the Storage interface on top of an SFTP session.
//
// Every remote key passed to its methods is resolved relative to baseDir.
type SFTPStorage struct {
	host    string // server host as given in Config.SFTPHost
	baseDir string // remote directory all keys are resolved against

	sshClient *ssh.Client  // underlying SSH connection
	client    *sftp.Client // SFTP session running over sshClient
}
// NewSFTPStorage creates a new SFTP storage instance.
//
// It connects to config.SFTPHost on config.SFTPPort (default "22") over SSH
// and opens an SFTP session on that connection.
//
// Authentication methods offered:
//   - password authentication with config.SFTPPassword (falling back to the
//     PGBK_SSH_PASS environment variable), but only when no identity file is
//     configured;
//   - public-key authentication with the keys returned by getSSHSigners
//     (identity file, decrypted with the same password, plus SSH agent keys).
//
// The login user defaults to the current OS user. The base directory defaults
// to the SFTP session's working directory. Host keys are checked by
// getHostKeyCallback unless config.SFTPIgnoreKnownHosts is set.
//
// On any failure after the SSH connection is established, the already opened
// clients are closed before the error is returned.
func NewSFTPStorage(config Config) (*SFTPStorage, error) {
	if err := validateConfig(config, map[string]string{
		"host": config.SFTPHost,
	}); err != nil {
		return nil, fmt.Errorf("sftp: %w", err)
	}

	port := config.SFTPPort
	if port == "" {
		port = "22"
	}

	username := config.SFTPUsername
	if username == "" {
		currentUser, err := user.Current()
		if err != nil {
			return nil, fmt.Errorf("sftp: failed to get current user: %w", err)
		}
		username = currentUser.Username
	}

	password := config.SFTPPassword
	if password == "" {
		password = os.Getenv("PGBK_SSH_PASS")
	}

	// Build authentication methods.
	authMethods := make([]ssh.AuthMethod, 0)
	// When an identity file is configured, the password is passed to
	// getSSHSigners as the key passphrase instead of being offered as a
	// login password.
	if config.SFTPIdentityFile == "" && password != "" {
		authMethods = append(authMethods, ssh.Password(password))
	}
	signers, err := getSSHSigners(config.SFTPIdentityFile, password)
	if err != nil {
		return nil, fmt.Errorf("sftp: %w", err)
	}
	if len(signers) > 0 {
		authMethods = append(authMethods, ssh.PublicKeys(signers...))
	}

	sshConfig := &ssh.ClientConfig{
		User:            username,
		Auth:            authMethods,
		HostKeyCallback: getHostKeyCallback(config.SFTPIgnoreKnownHosts),
	}

	// A plain "host:port" concatenation is invalid for IPv6 literals
	// ("::1:22"). Accept both bare ("::1") and bracketed ("[::1]") forms and
	// let net.JoinHostPort add the brackets the dial address needs.
	dialHost := strings.TrimSuffix(strings.TrimPrefix(config.SFTPHost, "["), "]")
	hostPort := net.JoinHostPort(dialHost, port)
	sshClient, err := ssh.Dial("tcp", hostPort, sshConfig)
	if err != nil {
		return nil, fmt.Errorf("sftp: failed to connect to %s: %w", hostPort, err)
	}

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		sshClient.Close()
		return nil, fmt.Errorf("sftp: failed to open sftp session: %w", err)
	}

	baseDir := config.SFTPDirectory
	if baseDir == "" {
		wd, err := sftpClient.Getwd()
		if err != nil {
			sftpClient.Close()
			sshClient.Close()
			return nil, fmt.Errorf("sftp: failed to get working directory: %w", err)
		}
		baseDir = wd
	}

	return &SFTPStorage{
		host:      config.SFTPHost,
		baseDir:   baseDir,
		sshClient: sshClient,
		client:    sftpClient,
	}, nil
}
// Upload copies the local file at localPath to remotePath over SFTP.
//
// remotePath is resolved relative to the storage base directory. Missing
// parent directories are created on the server first.
//
// ctx is checked once before any work starts. A transfer that is already in
// progress is not interrupted by cancellation.
func (s *SFTPStorage) Upload(ctx context.Context, localPath string, remotePath string) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("sftp: %w", err)
	}

	src, err := os.Open(localPath)
	if err != nil {
		return fmt.Errorf("sftp: failed to open local file: %w", err)
	}
	defer src.Close()

	fullRemotePath := normalizePathSeparators(filepath.Join(s.baseDir, remotePath))

	// Create the parent directory if needed.
	remoteDir := normalizePathSeparators(filepath.Dir(fullRemotePath))
	if remoteDir != "." && remoteDir != "/" {
		if err := s.client.MkdirAll(remoteDir); err != nil {
			return fmt.Errorf("sftp: failed to create directory %s: %w", remoteDir, err)
		}
	}

	dst, err := s.client.Create(fullRemotePath)
	if err != nil {
		return fmt.Errorf("sftp: failed to create remote file %s: %w", fullRemotePath, err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		return fmt.Errorf("sftp: failed to transfer data: %w", err)
	}
	// Closing the remote handle can itself fail. Dropping that error, as a
	// deferred Close would, could report success for an incomplete upload.
	if err := dst.Close(); err != nil {
		return fmt.Errorf("sftp: failed to close remote file %s: %w", fullRemotePath, err)
	}
	return nil
}
// Download copies remotePath to the local file at localPath over SFTP.
//
// remotePath is resolved relative to the storage base directory. The remote
// file is opened before localPath is created, so a missing remote file
// leaves any existing local file untouched. If the transfer or the final
// close fails, the partially written local file is removed.
//
// ctx is checked once before any work starts. A transfer that is already in
// progress is not interrupted by cancellation.
func (s *SFTPStorage) Download(ctx context.Context, remotePath string, localPath string) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("sftp: %w", err)
	}

	fullRemotePath := normalizePathSeparators(filepath.Join(s.baseDir, remotePath))
	src, err := s.client.Open(fullRemotePath)
	if err != nil {
		return fmt.Errorf("sftp: failed to open remote file %s: %w", fullRemotePath, err)
	}
	defer src.Close()

	dst, err := os.Create(localPath)
	if err != nil {
		return fmt.Errorf("sftp: failed to create local file: %w", err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		dst.Close()
		// Best effort: a truncated file must not be mistaken for a complete
		// download. The transfer error is the one worth reporting.
		os.Remove(localPath)
		return fmt.Errorf("sftp: failed to transfer data: %w", err)
	}
	// Write errors can surface only at close time, so the close must be checked.
	if err := dst.Close(); err != nil {
		os.Remove(localPath)
		return fmt.Errorf("sftp: failed to close local file: %w", err)
	}
	return nil
}
// List walks the storage base directory recursively over SFTP.
//
// It returns every file and directory whose path relative to the base
// directory (with normalized separators) starts with prefix. The base
// directory itself is not reported. Any walk error aborts the listing.
// The walk stops early with ctx's error if ctx is cancelled.
func (s *SFTPStorage) List(ctx context.Context, prefix string) ([]Item, error) {
	items := make([]Item, 0)
	baseDir := normalizePathSeparators(s.baseDir)

	walker := s.client.Walk(baseDir)
	for walker.Step() {
		if err := ctx.Err(); err != nil {
			return nil, fmt.Errorf("sftp: %w", err)
		}
		if err := walker.Err(); err != nil {
			return nil, fmt.Errorf("sftp: walk error: %w", err)
		}

		relPath, err := filepath.Rel(baseDir, walker.Path())
		if err != nil {
			continue
		}
		relPath = normalizePathSeparators(relPath)

		// The walk starts at baseDir itself, whose relative path is ".". It is
		// not an entry *under* the base directory, so it is not an item.
		if relPath == "." {
			continue
		}
		if !strings.HasPrefix(relPath, prefix) {
			continue
		}

		stat := walker.Stat()
		items = append(items, Item{
			Key:          relPath,
			ModifiedTime: stat.ModTime(),
			Size:         stat.Size(),
			IsDirectory:  stat.IsDir(),
		})
	}
	return items, nil
}
// Remove deletes remotePath, resolved relative to the storage base directory,
// on the SFTP server. The error names the full remote path that was targeted.
func (s *SFTPStorage) Remove(ctx context.Context, remotePath string) error {
	target := normalizePathSeparators(filepath.Join(s.baseDir, remotePath))

	err := s.client.Remove(target)
	if err == nil {
		return nil
	}
	return fmt.Errorf("sftp: failed to remove %s: %w", target, err)
}
// Close shuts down the SFTP session and then the SSH connection under it.
//
// Both are always closed. If closing the SFTP session fails, that error is
// returned; otherwise the result of closing the SSH connection is returned.
func (s *SFTPStorage) Close() error {
	sessionErr := s.client.Close()
	connErr := s.sshClient.Close()
	if sessionErr != nil {
		return sessionErr
	}
	return connErr
}
// getSSHSigners collects signers for SSH public-key authentication.
//
// If identityFile is non-empty, it is expanded with expandHomeDir and parsed
// as a private key. A passphrase-protected key is decrypted with passphrase;
// an empty passphrase for such a key is reported as an error. Signers from
// the SSH agent at $SSH_AUTH_SOCK are appended after it.
//
// Agent problems (no socket, dial failure, listing failure) are ignored on
// purpose so that a missing agent does not block other authentication methods.
func getSSHSigners(identityFile string, passphrase string) ([]ssh.Signer, error) {
	signers := make([]ssh.Signer, 0)

	// Load identity file if provided.
	if identityFile != "" {
		path, err := expandHomeDir(identityFile)
		if err != nil {
			return nil, fmt.Errorf("failed to expand identity file path: %w", err)
		}
		keyData, err := os.ReadFile(path)
		if err != nil {
			return nil, fmt.Errorf("failed to read identity file %s: %w", path, err)
		}
		signer, err := ssh.ParsePrivateKey(keyData)
		if err != nil {
			var passErr *ssh.PassphraseMissingError
			if !errors.As(err, &passErr) {
				return nil, fmt.Errorf("failed to parse identity file: %w", err)
			}
			if passphrase == "" {
				return nil, fmt.Errorf("identity file %s is passphrase-protected but no passphrase was provided: %w", path, err)
			}
			signer, err = ssh.ParsePrivateKeyWithPassphrase(keyData, []byte(passphrase))
			if err != nil {
				return nil, fmt.Errorf("failed to decrypt identity file: %w", err)
			}
		}
		signers = append(signers, signer)
	}

	// Try to get keys from the SSH agent (best effort).
	socket := os.Getenv("SSH_AUTH_SOCK")
	if socket == "" {
		return signers, nil
	}
	conn, err := net.Dial("unix", socket)
	if err != nil {
		return signers, nil
	}
	agentSigners, err := agent.NewClient(conn).Signers()
	if err != nil || len(agentSigners) == 0 {
		// Nothing usable from the agent: don't leak the socket.
		conn.Close()
		return signers, nil
	}
	// NOTE(review): conn must stay open because the agent signers sign over it
	// during the SSH handshake. Nothing closes it afterwards, so it leaks once
	// per call. TODO: hand a closer back to the caller.
	return append(signers, agentSigners...), nil
}
// getHostKeyCallback builds the host key verification callback for SSH.
//
// When ignoreHostKey is true, every host key is accepted. Otherwise the
// system-wide /etc/ssh/ssh_known_hosts and the user's ~/.ssh/known_hosts
// (whichever exist) are loaded with knownhosts.New. If neither file exists,
// or loading fails, the returned callback rejects every host with an error
// explaining why.
func getHostKeyCallback(ignoreHostKey bool) ssh.HostKeyCallback {
	if ignoreHostKey {
		return ssh.InsecureIgnoreHostKey()
	}

	var existing []string
	candidates := []string{"/etc/ssh/ssh_known_hosts", "~/.ssh/known_hosts"}
	for _, candidate := range candidates {
		resolved, expandErr := expandHomeDir(candidate)
		if expandErr != nil {
			continue
		}
		if _, statErr := os.Stat(resolved); statErr != nil {
			continue
		}
		existing = append(existing, resolved)
	}

	// rejectAll yields a callback that fails every verification with reason.
	rejectAll := func(reason error) ssh.HostKeyCallback {
		return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return reason
		}
	}

	if len(existing) == 0 {
		return rejectAll(fmt.Errorf("no known_hosts files found for host key verification"))
	}
	verify, loadErr := knownhosts.New(existing...)
	if loadErr != nil {
		return rejectAll(fmt.Errorf("failed to load known_hosts: %w", loadErr))
	}
	return verify
}
// expandHomeDir resolves a leading "~" or "~user" in path to a home directory.
//
// Paths that do not start with "~" are only cleaned with filepath.Clean.
//   - "~" and "~/rest" use the current user's home directory: os.UserHomeDir
//     first, falling back to user.Current when that fails or is empty.
//   - "~name" and "~name/rest" use the home directory of user "name", found
//     with user.Lookup.
//
// Only "/" is treated as the separator after the user part. An empty
// resolved home directory is an error. When a "/rest" part is present, the
// result is filepath.Join(home, rest); otherwise the home directory is
// returned as is.
func expandHomeDir(path string) (string, error) {
	if !strings.HasPrefix(path, "~") {
		return filepath.Clean(path), nil
	}

	// Split "~name/rest" into name and rest. path[0] is the single-byte '~'.
	name := path[1:]
	rest := ""
	hasRest := false
	if slash := strings.Index(path, "/"); slash >= 0 {
		name = path[1:slash]
		rest = path[slash+1:]
		hasRest = true
	}

	var home string
	switch name {
	case "":
		dir, homeErr := os.UserHomeDir()
		if homeErr != nil || dir == "" {
			me, curErr := user.Current()
			if curErr != nil {
				return "", fmt.Errorf("failed to get home directory: %w", curErr)
			}
			dir = me.HomeDir
		}
		home = dir
	default:
		info, lookupErr := user.Lookup(name)
		if lookupErr != nil {
			return "", fmt.Errorf("failed to lookup user %s: %w", name, lookupErr)
		}
		home = info.HomeDir
	}

	if home == "" {
		return "", fmt.Errorf("empty home directory")
	}
	if !hasRest {
		return home, nil
	}
	return filepath.Join(home, rest), nil
}