mirror of
https://github.com/jkaninda/go-storage.git
synced 2026-03-09 11:09:02 +01:00
351 lines
8.5 KiB
Go
351 lines
8.5 KiB
Go
package go_storage
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"os/user"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/pkg/sftp"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/agent"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
)
|
|
|
|
// SFTPStorage implements Storage interface for SFTP.
|
|
type SFTPStorage struct {
|
|
host string
|
|
baseDir string
|
|
sshClient *ssh.Client
|
|
client *sftp.Client
|
|
}
|
|
|
|
// NewSFTPStorage creates a new SFTP storage instance.
|
|
func NewSFTPStorage(config Config) (*SFTPStorage, error) {
|
|
if err := validateConfig(config, map[string]string{
|
|
"host": config.SFTPHost,
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("sftp: %w", err)
|
|
}
|
|
|
|
port := config.SFTPPort
|
|
if port == "" {
|
|
port = "22"
|
|
}
|
|
|
|
username := config.SFTPUsername
|
|
if username == "" {
|
|
currentUser, err := user.Current()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sftp: failed to get current user: %w", err)
|
|
}
|
|
username = currentUser.Username
|
|
}
|
|
|
|
password := config.SFTPPassword
|
|
if password == "" {
|
|
password = os.Getenv("PGBK_SSH_PASS")
|
|
}
|
|
|
|
// Build authentication methods
|
|
authMethods := make([]ssh.AuthMethod, 0)
|
|
|
|
// Password authentication (if no identity file is specified)
|
|
if config.SFTPIdentityFile == "" && password != "" {
|
|
authMethods = append(authMethods, ssh.Password(password))
|
|
}
|
|
|
|
// Public key authentication
|
|
signers, err := getSSHSigners(config.SFTPIdentityFile, password)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sftp: %w", err)
|
|
}
|
|
if len(signers) > 0 {
|
|
authMethods = append(authMethods, ssh.PublicKeys(signers...))
|
|
}
|
|
|
|
// Build SSH client config
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: username,
|
|
Auth: authMethods,
|
|
HostKeyCallback: getHostKeyCallback(config.SFTPIgnoreKnownHosts),
|
|
}
|
|
|
|
// Connect to SSH server
|
|
hostPort := fmt.Sprintf("%s:%s", config.SFTPHost, port)
|
|
sshClient, err := ssh.Dial("tcp", hostPort, sshConfig)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sftp: failed to connect to %s: %w", hostPort, err)
|
|
}
|
|
|
|
// Open SFTP session
|
|
sftpClient, err := sftp.NewClient(sshClient)
|
|
if err != nil {
|
|
sshClient.Close()
|
|
return nil, fmt.Errorf("sftp: failed to open sftp session: %w", err)
|
|
}
|
|
|
|
baseDir := config.SFTPDirectory
|
|
if baseDir == "" {
|
|
wd, err := sftpClient.Getwd()
|
|
if err != nil {
|
|
sftpClient.Close()
|
|
sshClient.Close()
|
|
return nil, fmt.Errorf("sftp: failed to get working directory: %w", err)
|
|
}
|
|
baseDir = wd
|
|
}
|
|
|
|
return &SFTPStorage{
|
|
host: config.SFTPHost,
|
|
baseDir: baseDir,
|
|
sshClient: sshClient,
|
|
client: sftpClient,
|
|
}, nil
|
|
}
|
|
|
|
// Upload uploads a file via SFTP.
|
|
func (s *SFTPStorage) Upload(ctx context.Context, localPath string, remotePath string) error {
|
|
src, err := os.Open(localPath)
|
|
if err != nil {
|
|
return fmt.Errorf("sftp: failed to open local file: %w", err)
|
|
}
|
|
defer src.Close()
|
|
|
|
fullRemotePath := filepath.Join(s.baseDir, remotePath)
|
|
fullRemotePath = normalizePathSeparators(fullRemotePath)
|
|
|
|
// Create parent directory if needed
|
|
remoteDir := filepath.Dir(fullRemotePath)
|
|
remoteDir = normalizePathSeparators(remoteDir)
|
|
|
|
if remoteDir != "." && remoteDir != "/" {
|
|
if err := s.client.MkdirAll(remoteDir); err != nil {
|
|
return fmt.Errorf("sftp: failed to create directory %s: %w", remoteDir, err)
|
|
}
|
|
}
|
|
|
|
dst, err := s.client.Create(fullRemotePath)
|
|
if err != nil {
|
|
return fmt.Errorf("sftp: failed to create remote file %s: %w", fullRemotePath, err)
|
|
}
|
|
defer dst.Close()
|
|
|
|
if _, err := io.Copy(dst, src); err != nil {
|
|
return fmt.Errorf("sftp: failed to transfer data: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Download downloads a file via SFTP.
|
|
func (s *SFTPStorage) Download(ctx context.Context, remotePath string, localPath string) error {
|
|
dst, err := os.Create(localPath)
|
|
if err != nil {
|
|
return fmt.Errorf("sftp: failed to create local file: %w", err)
|
|
}
|
|
defer dst.Close()
|
|
|
|
fullRemotePath := filepath.Join(s.baseDir, remotePath)
|
|
fullRemotePath = normalizePathSeparators(fullRemotePath)
|
|
|
|
src, err := s.client.Open(fullRemotePath)
|
|
if err != nil {
|
|
return fmt.Errorf("sftp: failed to open remote file %s: %w", fullRemotePath, err)
|
|
}
|
|
defer src.Close()
|
|
|
|
if _, err := io.Copy(dst, src); err != nil {
|
|
return fmt.Errorf("sftp: failed to transfer data: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// List lists files via SFTP with the given prefix.
|
|
func (s *SFTPStorage) List(ctx context.Context, prefix string) ([]Item, error) {
|
|
items := make([]Item, 0)
|
|
|
|
baseDir := normalizePathSeparators(s.baseDir)
|
|
|
|
walker := s.client.Walk(baseDir)
|
|
for walker.Step() {
|
|
if err := walker.Err(); err != nil {
|
|
return nil, fmt.Errorf("sftp: walk error: %w", err)
|
|
}
|
|
|
|
// Get relative path
|
|
path := walker.Path()
|
|
relPath, err := filepath.Rel(baseDir, path)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
// Normalize separators for comparison
|
|
relPath = normalizePathSeparators(relPath)
|
|
|
|
if !strings.HasPrefix(relPath, prefix) {
|
|
continue
|
|
}
|
|
|
|
stat := walker.Stat()
|
|
items = append(items, Item{
|
|
Key: relPath,
|
|
ModifiedTime: stat.ModTime(),
|
|
Size: stat.Size(),
|
|
IsDirectory: stat.IsDir(),
|
|
})
|
|
}
|
|
|
|
return items, nil
|
|
}
|
|
|
|
// Remove deletes a file via SFTP.
|
|
func (s *SFTPStorage) Remove(ctx context.Context, remotePath string) error {
|
|
fullRemotePath := filepath.Join(s.baseDir, remotePath)
|
|
fullRemotePath = normalizePathSeparators(fullRemotePath)
|
|
|
|
if err := s.client.Remove(fullRemotePath); err != nil {
|
|
return fmt.Errorf("sftp: failed to remove %s: %w", fullRemotePath, err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close releases resources.
|
|
func (s *SFTPStorage) Close() error {
|
|
if err := s.client.Close(); err != nil {
|
|
s.sshClient.Close()
|
|
return err
|
|
}
|
|
return s.sshClient.Close()
|
|
}
|
|
|
|
// getSSHSigners returns SSH signers from identity file and SSH agent.
|
|
func getSSHSigners(identityFile string, passphrase string) ([]ssh.Signer, error) {
|
|
signers := make([]ssh.Signer, 0)
|
|
|
|
// Load identity file if provided
|
|
if identityFile != "" {
|
|
path, err := expandHomeDir(identityFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to expand identity file path: %w", err)
|
|
}
|
|
|
|
keyData, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read identity file %s: %w", path, err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(keyData)
|
|
if err != nil {
|
|
var passErr *ssh.PassphraseMissingError
|
|
if errors.As(err, &passErr) {
|
|
signer, err = ssh.ParsePrivateKeyWithPassphrase(keyData, []byte(passphrase))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt identity file: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("failed to parse identity file: %w", err)
|
|
}
|
|
}
|
|
|
|
signers = append(signers, signer)
|
|
}
|
|
|
|
// Try to get keys from SSH agent
|
|
socket := os.Getenv("SSH_AUTH_SOCK")
|
|
if socket != "" {
|
|
conn, err := net.Dial("unix", socket)
|
|
if err == nil {
|
|
agentClient := agent.NewClient(conn)
|
|
agentSigners, err := agentClient.Signers()
|
|
if err == nil {
|
|
signers = append(signers, agentSigners...)
|
|
}
|
|
}
|
|
}
|
|
|
|
return signers, nil
|
|
}
|
|
|
|
// getHostKeyCallback returns appropriate host key callback.
|
|
func getHostKeyCallback(ignoreHostKey bool) ssh.HostKeyCallback {
|
|
if ignoreHostKey {
|
|
return ssh.InsecureIgnoreHostKey()
|
|
}
|
|
|
|
knownHostsFiles := make([]string, 0)
|
|
for _, p := range []string{"/etc/ssh/ssh_known_hosts", "~/.ssh/known_hosts"} {
|
|
path, err := expandHomeDir(p)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if _, err := os.Stat(path); err == nil {
|
|
knownHostsFiles = append(knownHostsFiles, path)
|
|
}
|
|
}
|
|
|
|
if len(knownHostsFiles) == 0 {
|
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return fmt.Errorf("no known_hosts files found for host key verification")
|
|
}
|
|
}
|
|
|
|
callback, err := knownhosts.New(knownHostsFiles...)
|
|
if err != nil {
|
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
return fmt.Errorf("failed to load known_hosts: %w", err)
|
|
}
|
|
}
|
|
|
|
return callback
|
|
}
|
|
|
|
// expandHomeDir expands ~ in file paths to home directory.
//
// Paths that do not start with "~" are returned cleaned. A leading "~" refers
// to the current user's home directory and a leading "~name" to the home
// directory of user name; whatever follows the first "/" is joined onto that
// directory.
//
// An error is returned when the home directory cannot be determined, the
// named user cannot be looked up, or the resolved home directory is empty.
func expandHomeDir(path string) (string, error) {
	if len(path) == 0 || path[0] != '~' {
		return filepath.Clean(path), nil
	}

	// Separate the "~name" component from the remainder after the first "/".
	tilde, rest, hasRest := path, "", false
	if slash := strings.Index(path, "/"); slash >= 0 {
		tilde, rest, hasRest = path[:slash], path[slash+1:], true
	}
	name := tilde[1:]

	var home string
	switch name {
	case "":
		dir, err := os.UserHomeDir()
		if err != nil || dir == "" {
			// Fall back to the user database when the environment gives
			// no usable answer.
			current, err := user.Current()
			if err != nil {
				return "", fmt.Errorf("failed to get home directory: %w", err)
			}
			dir = current.HomeDir
		}
		home = dir
	default:
		account, err := user.Lookup(name)
		if err != nil {
			return "", fmt.Errorf("failed to lookup user %s: %w", name, err)
		}
		home = account.HomeDir
	}

	if home == "" {
		return "", fmt.Errorf("empty home directory")
	}
	if !hasRest {
		return home, nil
	}

	return filepath.Join(home, rest), nil
}
|