package main

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/gpuaas/platform/packages/shared/buildinfo"
)

// nodeTask is the task envelope decoded from the control plane's task-wait
// response (see pollOnce). Before execution the envelope is checked by
// validateTaskEnvelope; the exact checks (expiry, node match, signature) live
// there — NOTE(review): confirm which fields it verifies. Params holds
// task-type-specific arguments whose schema is not visible here.
type nodeTask struct {
	TaskID        string         `json:"task_id"`
	TaskType      string         `json:"task_type"`
	NodeID        string         `json:"node_id"`
	CorrelationID string         `json:"correlation_id"`
	IssuedAt      string         `json:"issued_at"`
	ExpiresAt     string         `json:"expires_at"`
	Params        map[string]any `json:"params"`
	Signature     string         `json:"signature"`
}

// nodeTaskResult is the outcome the agent reports back for a claimed task via
// postTaskResult. Status is a short state string (for example "rejected" when
// the task envelope fails validation); Error carries a human-readable cause and
// is omitted from the JSON when empty.
type nodeTaskResult struct {
	Status      string         `json:"status"`
	Error       string         `json:"error,omitempty"`
	CompletedAt string         `json:"completed_at"`
	Output      map[string]any `json:"output"`
}

// agent is the node agent: its configuration, HTTP plumbing and the mutable
// state carried between iterations of the main loop (run).
//
//   - transport is the agent's own transport (also reused as the template for
//     enrollmentHTTPClient); httpClient wraps it with the poll timeout.
//   - lastCertFP is the certificate fingerprint most recently logged as
//     verified or enrolled, so an unchanged certificate is not re-logged.
//   - lastRecoveryFP / lastRecoveryAt record the certificate produced by the
//     last successful recovery enrollment and when it was written; they feed
//     recentRecoveredCertStillRejected.
//   - pollFailureLog / enrollFailureLog throttle logging of repeated
//     task-poll and enrollment failures.
//   - guestTelemetry is created by newGuestTelemetryStore; its use is not
//     visible in this part of the file.
type agent struct {
	cfg              Config
	httpClient       *http.Client
	transport        *http.Transport
	log              *slog.Logger
	lastCertFP       string
	lastRecoveryFP   string
	lastRecoveryAt   time.Time
	pollFailureLog   pollFailureLogState
	enrollFailureLog pollFailureLogState
	metrics          nodeAgentMetrics
	guestTelemetry   *guestTelemetryStore
}

// nodeAgentMetrics holds monotonically increasing counters, each an
// atomic.Uint64 so it can be read without extra locking. In the visible code:
// task-poll counters are bumped by run's poll branch, taskClaimsTotal by
// pollOnce, identityRejectionsTotal when the control plane rejects the node
// identity, and the recoveryEnrollment* counters by recoverEnroll and its
// suppression checks. The task result/completion counters are presumably
// updated by the task execution path outside this chunk — confirm there.
type nodeAgentMetrics struct {
	taskPollAttempts                  atomic.Uint64
	taskPollSkippedNoCert             atomic.Uint64
	taskPollFailures                  atomic.Uint64
	taskClaimsTotal                   atomic.Uint64
	taskResultPostFailures            atomic.Uint64
	taskCompletedTotal                atomic.Uint64
	taskFailedTotal                   atomic.Uint64
	taskRejectedTotal                 atomic.Uint64
	identityRejectionsTotal           atomic.Uint64
	recoveryEnrollmentAttemptsTotal   atomic.Uint64
	recoveryEnrollmentSuccessTotal    atomic.Uint64
	recoveryEnrollmentFailuresTotal   atomic.Uint64
	recoveryEnrollmentSuppressedTotal atomic.Uint64
}

// Timing and error-code constants for the agent's retry loops.
//
//   - nodeIdentityRevokedPollBackoff / nodeIdentityRevokedRenewBackoff: fixed
//     delays before re-polling / re-checking renewal after the control plane
//     rejected the node identity and recovery did not succeed.
//   - certExpiredRecoveryTokenRequiredCode: machine-readable prefix of the error
//     returned when the local cert has expired and no enrollment token is set;
//     isCredentialPollFailure matches on it.
//   - minTaskPollRetryDelay / maxTaskPollRetryDelay: bounds of the exponential
//     task-poll backoff.
//   - credentialTaskPollBackoff: minimum poll backoff after a credential
//     failure, and also the window after a recovery enrollment during which
//     another recovery of the same certificate is suppressed.
//   - rateLimitedTaskPollBackoff: minimum poll backoff after a rate-limited poll.
//   - pollFailureSummaryInterval: minimum spacing between repeated-failure
//     summary log lines.
const (
	nodeIdentityRevokedPollBackoff       = 5 * time.Minute
	nodeIdentityRevokedRenewBackoff      = 30 * time.Minute
	certExpiredRecoveryTokenRequiredCode = "cert_expired_recovery_token_required"
	minTaskPollRetryDelay                = time.Second
	maxTaskPollRetryDelay                = time.Minute
	credentialTaskPollBackoff            = 5 * time.Minute
	rateLimitedTaskPollBackoff           = 5 * time.Minute
	pollFailureSummaryInterval           = 5 * time.Minute
)

// Sentinel errors wrapped into agent errors and matched with errors.Is to
// choose a recovery path.
var (
	// ErrNodeIdentityRevoked marks a request the control plane rejected
	// because it no longer accepts this node's identity.
	ErrNodeIdentityRevoked = errors.New("node identity revoked")
	// ErrLocalCertExpired marks that the on-disk node certificate is already
	// past its validity window.
	ErrLocalCertExpired = errors.New("local certificate expired")
)

// newAgent constructs an agent from cfg.
//
// A random InstanceID is generated when cfg leaves it blank. For an https API
// URL the transport requires TLS 1.2+, trusts the system roots plus
// cfg.CABundlePath (falling back to system trust, with a warning, if the bundle
// cannot be loaded), and presents the node client certificate read fresh from
// cfg.CertPath/cfg.KeyPath on every handshake. The HTTP client timeout is the
// poll timeout plus a 5s margin.
func newAgent(cfg Config, logger *slog.Logger) *agent {
	if strings.TrimSpace(cfg.InstanceID) == "" {
		cfg.InstanceID = generateNodeInstanceID()
	}

	transport := &http.Transport{}
	if strings.HasPrefix(cfg.APIURL, "https://") {
		tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
		if roots, err := loadRootCAs(cfg.CABundlePath); err == nil {
			tlsConfig.RootCAs = roots
		} else {
			logger.Warn("failed to load custom ca bundle; falling back to system trust", "path", cfg.CABundlePath, "error", err)
		}
		certPath, keyPath := cfg.CertPath, cfg.KeyPath
		// Reading the pair per handshake picks up certificates rotated on disk
		// by enrollment/renewal. When no usable pair exists, an empty
		// certificate is returned so the handshake proceeds without presenting
		// a client certificate.
		tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
			if fileExists(certPath) && fileExists(keyPath) {
				if pair, loadErr := tls.LoadX509KeyPair(certPath, keyPath); loadErr == nil {
					return &pair, nil
				}
			}
			return &tls.Certificate{}, nil
		}
		transport.TLSClientConfig = tlsConfig
	}

	return &agent{
		cfg: cfg,
		httpClient: &http.Client{
			Timeout:   cfg.PollTimeout + 5*time.Second,
			Transport: transport,
		},
		transport:      transport,
		log:            logger,
		guestTelemetry: newGuestTelemetryStore(),
	}
}

// enrollmentHTTPClient builds a dedicated HTTP client for recovery enrollment.
//
// It starts from a clone of the agent's transport (or of http.DefaultTransport
// when the agent has none) and strips any client-certificate configuration, so
// enrollment never presents the node certificate it is trying to replace. When
// the recovery endpoint is https, the recovery CA bundle (falling back to the
// main CA bundle) replaces the root pool if it loads; otherwise the existing
// trust is kept and a warning is logged. The client reuses the agent's request
// timeout. Its transport is not shared, so callers should CloseIdleConnections
// when done.
func (a *agent) enrollmentHTTPClient() *http.Client {
	var transport *http.Transport
	if a.transport != nil {
		transport = a.transport.Clone()
	} else if base, ok := http.DefaultTransport.(*http.Transport); ok {
		// Only clone the default transport when it is actually needed, and use
		// a checked assertion: DefaultTransport may have been replaced with
		// another RoundTripper.
		transport = base.Clone()
	} else {
		transport = &http.Transport{}
	}

	isHTTPS := strings.HasPrefix(a.recoveryAPIURL(), "https://")
	if !isHTTPS && transport.TLSClientConfig == nil {
		// Plain http with no TLS settings: nothing to strip or configure.
		return &http.Client{Timeout: a.httpClient.Timeout, Transport: transport}
	}

	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
	if transport.TLSClientConfig != nil {
		tlsConfig = transport.TLSClientConfig.Clone()
	}
	if isHTTPS {
		if recoveryCAPath := a.recoveryCABundlePath(); strings.TrimSpace(recoveryCAPath) != "" {
			roots, err := loadRootCAs(recoveryCAPath)
			if err != nil {
				a.log.Warn("failed to load recovery ca bundle; falling back to existing trust", "path", recoveryCAPath, "error", err)
			} else {
				tlsConfig.RootCAs = roots
			}
		}
	}
	// Enrollment authenticates with the enrollment token, never with the
	// (possibly expired or revoked) node client certificate.
	tlsConfig.GetClientCertificate = nil
	tlsConfig.Certificates = nil
	transport.TLSClientConfig = tlsConfig

	return &http.Client{
		Timeout:   a.httpClient.Timeout,
		Transport: transport,
	}
}

// recoveryAPIURL returns the base URL used for recovery enrollment: the
// dedicated RecoveryAPIURL when configured, otherwise the main APIURL. In both
// cases surrounding whitespace and trailing slashes are removed.
func (a *agent) recoveryAPIURL() string {
	normalize := func(raw string) string {
		return strings.TrimRight(strings.TrimSpace(raw), "/")
	}
	if override := normalize(a.cfg.RecoveryAPIURL); override != "" {
		return override
	}
	return normalize(a.cfg.APIURL)
}

// recoveryCABundlePath returns the CA bundle used to trust the recovery
// endpoint: RecoveryCABundlePath when set, otherwise CABundlePath (both
// whitespace-trimmed; may be empty).
func (a *agent) recoveryCABundlePath() string {
	path := strings.TrimSpace(a.cfg.RecoveryCABundlePath)
	if path == "" {
		path = strings.TrimSpace(a.cfg.CABundlePath)
	}
	return path
}

// generateNodeInstanceID returns a random 12-byte identifier encoded as
// unpadded URL-safe base64 (16 characters). If the system random source fails,
// it falls back to "pid-<pid>-<unix-nanos>".
func generateNodeInstanceID() string {
	buf := make([]byte, 12)
	if _, err := rand.Read(buf); err == nil {
		return base64.RawURLEncoding.EncodeToString(buf)
	}
	return fmt.Sprintf("pid-%d-%d", os.Getpid(), time.Now().UTC().UnixNano())
}

// run is the agent's main loop. It performs one enrollment check and, unless
// that check just issued a certificate, one renewal check; then it multiplexes
// three timers until ctx is cancelled:
//   - pollTimer: polls for tasks (first poll immediately, ~1s after a success,
//     exponential backoff with credential/rate-limit floors after failures);
//   - enrollTimer: re-runs maybeEnroll every 30s, backing off up to 10m;
//   - renewTimer: runs maybeRenewCert every 5m, backing off up to 30m.
//
// Credential failures and identity rejections trigger recovery enrollment when
// an enrollment token is configured; recentRecoveredCertStillRejected prevents
// re-enrolling again while the certificate from a recent recovery is still the
// one on disk. run returns nil once ctx is done.
func (a *agent) run(ctx context.Context) error {
	// Startup pass. A failed enrollment is only logged: the periodic timers
	// below keep retrying.
	enrolledThisStart, err := a.maybeEnroll(ctx)
	if err != nil {
		a.log.Warn("node enrollment attempt failed", "error", err)
	}
	if enrolledThisStart {
		a.log.Info("skipping immediate certificate renewal after fresh enrollment", "node_id", a.cfg.NodeID)
	} else if err := a.maybeRenewCert(ctx); err != nil {
		a.log.Warn("certificate renewal check failed", "error", err)
	}

	pollTimer := time.NewTimer(0)
	enrollTimer := time.NewTimer(30 * time.Second)
	renewTimer := time.NewTimer(5 * time.Minute)
	defer pollTimer.Stop()
	defer enrollTimer.Stop()
	defer renewTimer.Stop()
	// Consecutive-failure counters driving each timer's backoff and log
	// throttling; reset on success.
	pollFailureCount := 0
	enrollFailureCount := 0
	renewFailureCount := 0
	revokedPollCount := 0
	revokedRenewCount := 0

	for {
		select {
		case <-ctx.Done():
			// Cancellation is treated as a clean shutdown.
			return nil
		case <-pollTimer.C:
			a.metrics.taskPollAttempts.Add(1)
			// No cert/key pair on disk yet: skip the poll and back off until
			// enrollment produces one.
			if !fileExists(a.cfg.CertPath) || !fileExists(a.cfg.KeyPath) {
				a.metrics.taskPollSkippedNoCert.Add(1)
				pollFailureCount++
				backoff := pollBackoffDuration(pollFailureCount)
				if pollFailureCount == 1 || pollFailureCount%5 == 0 {
					a.log.Warn("task poll skipped until node certificate enrollment succeeds", "failure_count", pollFailureCount, "next_poll_after", backoff.String())
				}
				pollTimer.Reset(backoff)
				continue
			}
			if err := a.pollOnce(ctx); err != nil {
				// Expired or rejected credential with a token available: try
				// recovery enrollment, unless the certificate on disk is still
				// the one recovered within the suppression window.
				if a.cfg.EnrollmentToken != "" && (errors.Is(err, ErrLocalCertExpired) || isCredentialPollFailure(err)) {
					if suppressed, fingerprint, nextRetryAfter := a.recentRecoveredCertStillRejected(); suppressed {
						a.metrics.recoveryEnrollmentSuppressedTotal.Add(1)
						pollFailureCount++
						a.metrics.taskPollFailures.Add(1)
						a.log.Warn("node credential still rejected after recovery enrollment; backing off before another enrollment attempt",
							"error", err,
							"fingerprint_sha256", fingerprint,
							"failure_count", pollFailureCount,
							"next_poll_after", nextRetryAfter.String(),
						)
						pollTimer.Reset(nextRetryAfter)
						continue
					}
					a.log.Warn("node credential failed during task polling; attempting recovery enrollment", "error", err)
					enrolled, enrollErr := a.recoverEnroll(ctx, "task_poll_credential_failure")
					if enrollErr == nil {
						pollFailureCount = 0
						enrollFailureCount = 0
						if enrolled {
							a.log.Info("certificate recovered via enrollment token", "node_id", a.cfg.NodeID)
						}
						pollTimer.Reset(1 * time.Second)
						continue
					}
					// Recovery failed: keep both causes and fall through to the
					// identity/generic handling below.
					err = fmt.Errorf("%w; recovery enrollment failed: %v", err, enrollErr)
				}
				if errors.Is(err, ErrNodeIdentityRevoked) {
					a.metrics.identityRejectionsTotal.Add(1)
					recovered, recoverErr := a.tryRecoverRejectedIdentity(ctx, "task_poll_identity_rejected", err)
					if recoverErr != nil {
						err = fmt.Errorf("%w; recovery enrollment failed: %v", err, recoverErr)
					}
					if recovered {
						pollFailureCount = 0
						enrollFailureCount = 0
						revokedPollCount = 0
						a.log.Info("node identity recovered via enrollment token; resuming task polling", "node_id", a.cfg.NodeID)
						pollTimer.Reset(1 * time.Second)
						continue
					}
					// Still revoked: fixed long backoff; log on the first
					// attempt and then every 12th.
					revokedPollCount++
					a.writeLocalDiagnostic("task_poll", err, revokedPollCount, nodeIdentityRevokedPollBackoff)
					if revokedPollCount == 1 || revokedPollCount%12 == 0 {
						a.log.Warn("node identity revoked; pausing task polling until node is reactivated or re-enrolled", "error", err, "attempts", revokedPollCount, "next_poll_after", nodeIdentityRevokedPollBackoff.String())
					}
					pollTimer.Reset(nodeIdentityRevokedPollBackoff)
					continue
				}
				// Generic failure: exponential backoff, raised to the
				// credential / rate-limit floors when the error matches those
				// classes.
				revokedPollCount = 0
				a.metrics.taskPollFailures.Add(1)
				pollFailureCount++
				backoff := pollBackoffDuration(pollFailureCount)
				if isCredentialPollFailure(err) && backoff < credentialTaskPollBackoff {
					backoff = credentialTaskPollBackoff
				}
				if isRateLimitedPollFailure(err) && backoff < rateLimitedTaskPollBackoff {
					backoff = rateLimitedTaskPollBackoff
				}
				if backoff < minTaskPollRetryDelay {
					backoff = minTaskPollRetryDelay
				}
				a.logTaskPollFailure(err, pollFailureCount, backoff)
				pollTimer.Reset(backoff)
				continue
			}
			revokedPollCount = 0
			pollFailureCount = 0
			a.logTaskPollRecovered()
			pollTimer.Reset(1 * time.Second)
		case <-enrollTimer.C:
			// Periodic enrollment check (no-op when the local cert is fine).
			enrolledThisAttempt, err := a.maybeEnroll(ctx)
			if err != nil {
				enrollFailureCount++
				backoff := maintenanceBackoffDuration(30*time.Second, 10*time.Minute, enrollFailureCount)
				a.logNodeEnrollmentFailure(err, enrollFailureCount, backoff)
				enrollTimer.Reset(backoff)
				continue
			}
			enrollFailureCount = 0
			a.logNodeEnrollmentRecovered()
			if enrolledThisAttempt {
				a.log.Info("skipping immediate certificate renewal after fresh enrollment", "node_id", a.cfg.NodeID)
			}
			enrollTimer.Reset(30 * time.Second)
		case <-renewTimer.C:
			if err := a.maybeRenewCert(ctx); err != nil {
				// Identity rejected during renewal: same recovery flow as the
				// poll branch, with a longer fixed backoff.
				if errors.Is(err, ErrNodeIdentityRevoked) {
					a.metrics.identityRejectionsTotal.Add(1)
					recovered, recoverErr := a.tryRecoverRejectedIdentity(ctx, "certificate_renew_identity_rejected", err)
					if recoverErr != nil {
						err = fmt.Errorf("%w; recovery enrollment failed: %v", err, recoverErr)
					}
					if recovered {
						renewFailureCount = 0
						enrollFailureCount = 0
						revokedRenewCount = 0
						a.log.Info("node identity recovered via enrollment token; resuming certificate renewal checks", "node_id", a.cfg.NodeID)
						renewTimer.Reset(5 * time.Minute)
						continue
					}
					revokedRenewCount++
					a.writeLocalDiagnostic("renewal", err, revokedRenewCount, nodeIdentityRevokedRenewBackoff)
					if revokedRenewCount == 1 || revokedRenewCount%3 == 0 {
						a.log.Warn("node identity revoked; skipping certificate renewal until node is reactivated or re-enrolled", "error", err, "attempts", revokedRenewCount, "next_retry_after", nodeIdentityRevokedRenewBackoff.String())
					}
					renewTimer.Reset(nodeIdentityRevokedRenewBackoff)
					continue
				}
				// Credential failure during renewal: attempt recovery
				// enrollment (subject to the same suppression window).
				if a.cfg.EnrollmentToken != "" && (errors.Is(err, ErrLocalCertExpired) || isCredentialPollFailure(err)) {
					if suppressed, fingerprint, nextRetryAfter := a.recentRecoveredCertStillRejected(); suppressed {
						a.metrics.recoveryEnrollmentSuppressedTotal.Add(1)
						renewFailureCount++
						a.log.Warn("node credential still rejected during certificate renewal after recovery enrollment; backing off before another enrollment attempt",
							"error", err,
							"fingerprint_sha256", fingerprint,
							"failure_count", renewFailureCount,
							"next_retry_after", nextRetryAfter.String(),
						)
						renewTimer.Reset(nextRetryAfter)
						continue
					}
					a.log.Warn("node credential failed during certificate renewal; attempting recovery enrollment", "error", err)
					enrolled, enrollErr := a.recoverEnroll(ctx, "certificate_renew_credential_failure")
					if enrollErr == nil {
						renewFailureCount = 0
						enrollFailureCount = 0
						if enrolled {
							a.log.Info("certificate recovered via enrollment token", "node_id", a.cfg.NodeID)
						}
						renewTimer.Reset(5 * time.Minute)
						continue
					}
					err = fmt.Errorf("%w; recovery enrollment failed: %v", err, enrollErr)
				}
				// Generic renewal failure: exponential backoff from 5m to 30m.
				revokedRenewCount = 0
				renewFailureCount++
				backoff := maintenanceBackoffDuration(5*time.Minute, 30*time.Minute, renewFailureCount)
				a.writeLocalDiagnostic("renewal", err, renewFailureCount, backoff)
				a.log.Warn("certificate renewal check failed", "error", err, "reason_code", classifyNodeAgentLocalFailure("renewal", err), "failure_count", renewFailureCount, "next_retry_after", backoff.String())
				renewTimer.Reset(backoff)
				continue
			}
			revokedRenewCount = 0
			renewFailureCount = 0
			renewTimer.Reset(5 * time.Minute)
		}
	}
}

// pollBackoffDuration maps a consecutive task-poll failure count to the delay
// before the next poll. The first failure (or a non-positive count) yields
// minTaskPollRetryDelay. Each further failure doubles the delay, which is
// capped at maxTaskPollRetryDelay.
func pollBackoffDuration(failureCount int) time.Duration {
	delay := minTaskPollRetryDelay
	for attempt := 2; attempt <= failureCount; attempt++ {
		// Saturate before doubling so the delay never overshoots the cap.
		if delay >= maxTaskPollRetryDelay/2 {
			return maxTaskPollRetryDelay
		}
		delay *= 2
	}
	return normalizeTaskPollDelay(delay)
}

// normalizeTaskPollDelay clamps delay into
// [minTaskPollRetryDelay, maxTaskPollRetryDelay].
func normalizeTaskPollDelay(delay time.Duration) time.Duration {
	switch {
	case delay < minTaskPollRetryDelay:
		return minTaskPollRetryDelay
	case delay > maxTaskPollRetryDelay:
		return maxTaskPollRetryDelay
	default:
		return delay
	}
}

// isCredentialPollFailure reports whether err looks like a problem with the
// node's TLS credentials rather than with connectivity. It matches a wrapped
// ErrLocalCertExpired, or a case-insensitive mention of certificates, a "tls:"
// error, or the cert-expired/recovery-token-required code. A nil err is never
// a credential failure.
func isCredentialPollFailure(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, ErrLocalCertExpired):
		return true
	}
	text := strings.ToLower(err.Error())
	for _, needle := range []string{"certificate", "tls:", "ssl certificate", certExpiredRecoveryTokenRequiredCode} {
		if strings.Contains(text, needle) {
			return true
		}
	}
	return false
}

// isRateLimitedPollFailure reports whether err's message (case-insensitively)
// indicates rate limiting: an HTTP 429 status, a rate_limit_exceeded code, a
// "rate limited" phrase, or an "error 1200" edge response. A nil err is never
// rate limited.
func isRateLimitedPollFailure(err error) bool {
	if err == nil {
		return false
	}
	lowered := strings.ToLower(err.Error())
	markers := [...]string{
		"status=429",
		"rate_limit_exceeded",
		"rate limited",
		"temporarily rate limited",
		"error 1200",
	}
	for _, marker := range markers {
		if strings.Contains(lowered, marker) {
			return true
		}
	}
	return false
}

// pollFailureLogState tracks one run of identical failures (same
// pollFailureFingerprint) so that logConnectivityFailure logs the first
// occurrence in full and afterwards only periodic summaries.
//
//   - fingerprint: fingerprint of the failure currently being tracked.
//   - firstAt: when this fingerprint was first seen.
//   - lastAt: when it was last logged (first log or latest summary).
//   - count: failures seen with this fingerprint; 0 means no active run
//     (the zero value is the reset state used by logConnectivityRecovered).
type pollFailureLogState struct {
	fingerprint string
	firstAt     time.Time
	lastAt      time.Time
	count       int
}

// logTaskPollFailure records a failed task poll through the shared,
// de-duplicating connectivity failure logger.
func (a *agent) logTaskPollFailure(err error, failureCount int, nextPollAfter time.Duration) {
	const (
		operation      = "task_poll"
		firstMessage   = "task poll failed"
		summaryMessage = "task poll still failing"
		delayKey       = "next_poll_after"
	)
	a.logConnectivityFailure(&a.pollFailureLog, operation, firstMessage, summaryMessage, delayKey, err, failureCount, nextPollAfter)
}

// logTaskPollRecovered emits a one-time recovery line (if polls had been
// failing) and clears the task-poll failure tracking state.
func (a *agent) logTaskPollRecovered() {
	state := &a.pollFailureLog
	a.logConnectivityRecovered(state, "task poll recovered")
}

// logNodeEnrollmentFailure records a failed periodic enrollment check through
// the shared, de-duplicating connectivity failure logger.
func (a *agent) logNodeEnrollmentFailure(err error, failureCount int, nextRetryAfter time.Duration) {
	const (
		operation      = "enrollment"
		firstMessage   = "node enrollment retry failed"
		summaryMessage = "node enrollment retry still failing"
		delayKey       = "next_retry_after"
	)
	a.logConnectivityFailure(&a.enrollFailureLog, operation, firstMessage, summaryMessage, delayKey, err, failureCount, nextRetryAfter)
}

// logNodeEnrollmentRecovered emits a one-time recovery line (if enrollment
// checks had been failing) and clears the enrollment failure tracking state.
func (a *agent) logNodeEnrollmentRecovered() {
	state := &a.enrollFailureLog
	a.logConnectivityRecovered(state, "node enrollment recovered")
}

// logConnectivityFailure logs a failure of operation while suppressing
// repetitive output.
//
// With a nil state every failure is logged with firstMessage. Otherwise a new
// failure fingerprint resets state and is logged in full with firstMessage. A
// repeat of the current fingerprint is logged with summaryMessage at most once
// per pollFailureSummaryInterval. The local diagnostic file is updated on every
// call. nextDelayKey names the log attribute carrying nextDelay.
func (a *agent) logConnectivityFailure(state *pollFailureLogState, operation, firstMessage, summaryMessage, nextDelayKey string, err error, failureCount int, nextDelay time.Duration) {
	reason := classifyNodeAgentLocalFailure(operation, err)
	delayText := nextDelay.String()

	if state == nil {
		a.log.Warn(firstMessage, "error", err, "reason_code", reason, "failure_count", failureCount, nextDelayKey, delayText)
		a.writeLocalDiagnostic(operation, err, failureCount, nextDelay)
		return
	}

	at := time.Now().UTC()
	fp := pollFailureFingerprint(err)

	if fp == state.fingerprint {
		// Same failure as before: count it, refresh the diagnostic, and only
		// emit a summary once the throttle interval has elapsed.
		state.count++
		a.writeLocalDiagnostic(operation, err, state.count, nextDelay)
		if at.Sub(state.lastAt) >= pollFailureSummaryInterval {
			state.lastAt = at
			a.log.Warn(summaryMessage,
				"error", err,
				"reason_code", reason,
				"failure_count", failureCount,
				"fingerprint", fp,
				"same_error_count", state.count,
				"duration", at.Sub(state.firstAt).Round(time.Second).String(),
				nextDelayKey, delayText,
			)
		}
		return
	}

	// A different failure signature starts a new run.
	*state = pollFailureLogState{fingerprint: fp, firstAt: at, lastAt: at, count: 1}
	a.log.Warn(firstMessage,
		"error", err,
		"reason_code", reason,
		"failure_count", failureCount,
		"fingerprint", fp,
		nextDelayKey, delayText,
	)
	a.writeLocalDiagnostic(operation, err, state.count, nextDelay)
}

// logConnectivityRecovered logs message once when a tracked failure run ends,
// reporting how many repeat failures were suppressed and how long the run
// lasted, then resets state. It does nothing when state is nil or no failure
// run is active.
func (a *agent) logConnectivityRecovered(state *pollFailureLogState, message string) {
	if state != nil && state.count != 0 {
		suppressed := state.count - 1
		if suppressed < 0 {
			suppressed = 0
		}
		elapsed := time.Now().UTC().Sub(state.firstAt).Round(time.Second)
		a.log.Info(message,
			"fingerprint", state.fingerprint,
			"suppressed_error_count", suppressed,
			"duration", elapsed.String(),
		)
		*state = pollFailureLogState{}
	}
}

// nodeAgentLocalDiagnostic is the JSON snapshot writeLocalDiagnostic stores at
// cfg.DiagnosticPath describing the most recent failure. ReasonCode comes from
// classifyNodeAgentLocalFailure; Message is the sanitized (body-redacted,
// length-capped) error text; the *Host fields are host[:port] only; timestamps
// are RFC 3339 in UTC; Count is the caller-supplied failure count (at least 1).
type nodeAgentLocalDiagnostic struct {
	NodeID          string `json:"node_id"`
	Operation       string `json:"operation"`
	ReasonCode      string `json:"reason_code"`
	Message         string `json:"message"`
	APIHost         string `json:"api_host,omitempty"`
	RecoveryHost    string `json:"recovery_host,omitempty"`
	TerminalHost    string `json:"terminal_host,omitempty"`
	CertFingerprint string `json:"cert_fingerprint_sha256,omitempty"`
	FirstSeenAt     string `json:"first_seen_at"`
	LastSeenAt      string `json:"last_seen_at"`
	Count           int    `json:"count"`
	NextRetryAfter  string `json:"next_retry_after,omitempty"`
	AgentVersion    string `json:"agent_version,omitempty"`
}

// writeLocalDiagnostic persists a JSON snapshot of the latest failure of
// operation to a.cfg.DiagnosticPath. It is a no-op when no path is configured
// or err is nil.
//
// first_seen_at is carried over from the previous snapshot while the same
// operation keeps failing with the same reason code. The file is written to a
// sibling ".tmp" file and renamed into place. Every failure here is logged and
// otherwise ignored: the diagnostic is best-effort and never affects the caller.
func (a *agent) writeLocalDiagnostic(operation string, err error, count int, nextRetryAfter time.Duration) {
	path := strings.TrimSpace(a.cfg.DiagnosticPath)
	if path == "" || err == nil {
		return
	}
	op := strings.TrimSpace(operation)
	reasonCode := classifyNodeAgentLocalFailure(operation, err)
	now := time.Now().UTC()
	firstSeenAt := now
	if prevRaw, readErr := os.ReadFile(path); readErr == nil && len(prevRaw) > 0 {
		var previous nodeAgentLocalDiagnostic
		// A corrupt previous snapshot just means less history. On a type
		// mismatch json.Unmarshal still fills the fields it could decode, and
		// those are compared below.
		_ = json.Unmarshal(prevRaw, &previous)
		if previous.Operation == op && previous.ReasonCode == reasonCode {
			if parsed, parseErr := time.Parse(time.RFC3339, previous.FirstSeenAt); parseErr == nil {
				firstSeenAt = parsed
			}
		}
	}
	certFingerprint := ""
	if summary, summaryErr := currentCertSummary(a.cfg.CertPath); summaryErr == nil {
		certFingerprint = summary.Fingerprint
	}
	diagnostic := nodeAgentLocalDiagnostic{
		NodeID:          strings.TrimSpace(a.cfg.NodeID),
		Operation:       op,
		ReasonCode:      reasonCode,
		Message:         sanitizeDiagnosticMessage(err),
		APIHost:         urlHost(a.cfg.APIURL),
		RecoveryHost:    urlHost(a.recoveryAPIURL()),
		TerminalHost:    urlHost(a.cfg.TerminalAPIURL),
		CertFingerprint: certFingerprint,
		FirstSeenAt:     firstSeenAt.Format(time.RFC3339),
		LastSeenAt:      now.Format(time.RFC3339),
		Count:           max(1, count),
		AgentVersion:    strings.TrimSpace(buildinfo.Version),
	}
	if nextRetryAfter > 0 {
		diagnostic.NextRetryAfter = nextRetryAfter.String()
	}
	// Distinct error names keep the err parameter (the failure being recorded)
	// from being silently overwritten by I/O results below.
	encoded, encodeErr := json.MarshalIndent(diagnostic, "", "  ")
	if encodeErr != nil {
		a.log.Warn("node-agent local diagnostic encode failed", "error", encodeErr)
		return
	}
	if mkdirErr := os.MkdirAll(filepath.Dir(path), 0o750); mkdirErr != nil {
		a.log.Warn("node-agent local diagnostic directory create failed", "path", path, "error", mkdirErr)
		return
	}
	tmp := path + ".tmp"
	if writeErr := os.WriteFile(tmp, append(encoded, '\n'), 0o640); writeErr != nil {
		a.log.Warn("node-agent local diagnostic write failed", "path", path, "error", writeErr)
		// os.WriteFile may already have created (and partially written) tmp.
		_ = os.Remove(tmp)
		return
	}
	if renameErr := os.Rename(tmp, path); renameErr != nil {
		a.log.Warn("node-agent local diagnostic rename failed", "path", path, "error", renameErr)
		// Best effort: don't leave a stale temp file next to the diagnostic.
		_ = os.Remove(tmp)
	}
}

// classifyNodeAgentLocalFailure maps err to a stable reason code for logs and
// the local diagnostic file. Heuristics are tried in priority order, mostly
// by case-insensitive substring matching on the error text. When none apply,
// the code falls back to a per-operation default ("enrollment",
// "renewal", or anything else → task poll). A nil err is "healthy".
func classifyNodeAgentLocalFailure(operation string, err error) string {
	if err == nil {
		return "healthy"
	}
	msg := strings.ToLower(err.Error())
	has := func(needle string) bool { return strings.Contains(msg, needle) }

	switch {
	case has("recovery enrollment failed") && (has("x509:") || has("certificate") || has("tls:")):
		return "recovery_enrollment_blocked"
	case errors.Is(err, ErrNodeIdentityRevoked) || has("node_not_found") || has("identity revoked"):
		return "identity_revoked"
	case isRateLimitedPollFailure(err):
		return "edge_rate_limited"
	case errors.Is(err, ErrLocalCertExpired) || has("cert_expired"):
		return "cert_expired"
	case has("certificate is valid for") || (has("not ") && has("certificate")):
		// Explicit grouping: hostname-mismatch wording, or any "not ..." +
		// "certificate" phrasing.
		return "endpoint_profile_drift"
	case has("x509:") || has("unknown authority") || has("tls:") || has("ssl certificate"):
		return "server_tls_untrusted"
	case has("no such host") || has("connection refused") || has("i/o timeout") || has("context deadline exceeded"):
		return "endpoint_unreachable"
	case has("enrollment_token") || has("recovery token"):
		return "recovery_token_missing"
	}

	fallback := map[string]string{
		"enrollment": "enrollment_failed",
		"renewal":    "cert_renew_failed",
	}
	if code, ok := fallback[strings.TrimSpace(operation)]; ok {
		return code
	}
	return "task_poll_failed"
}

// sanitizeDiagnosticMessage renders err as a single-line, length-bounded
// message suitable for the local diagnostic file.
//
// Everything from the first case-insensitive "body=" marker onward is replaced
// with "body=[REDACTED]" so raw response bodies are not persisted. Whitespace
// runs collapse to single spaces. The result is capped at 240 bytes without
// splitting a multi-byte UTF-8 character. A nil err yields "".
func sanitizeDiagnosticMessage(err error) string {
	if err == nil {
		return ""
	}
	msg := strings.TrimSpace(err.Error())

	// Search msg itself instead of strings.ToLower(msg): lowercasing can change
	// the byte length of some non-ASCII characters, so an offset taken from
	// the lowered copy may point at the wrong byte of msg or past its end.
	const marker = "body="
	for i := 0; i+len(marker) <= len(msg); i++ {
		if strings.EqualFold(msg[i:i+len(marker)], marker) {
			msg = msg[:i] + "body=[REDACTED]"
			break
		}
	}

	msg = strings.Join(strings.Fields(msg), " ")

	const maxLen = 240
	if len(msg) > maxLen {
		cut := maxLen
		// Back up over UTF-8 continuation bytes (10xxxxxx) so the cut lands on
		// a rune boundary instead of producing invalid UTF-8.
		for cut > 0 && msg[cut]&0xC0 == 0x80 {
			cut--
		}
		msg = msg[:cut]
	}
	return msg
}

// urlHost returns the host[:port] portion of raw (after trimming whitespace),
// or "" when raw does not parse as a URL or carries no host.
func urlHost(raw string) string {
	parsed, err := url.Parse(strings.TrimSpace(raw))
	if err == nil {
		return parsed.Host
	}
	return ""
}

// pollFailureFingerprint reduces err to a short, normalized key used to detect
// repeated identical failures. The text is lowercased, cut before any "body="
// (so varying response bodies don't produce distinct fingerprints), has its
// whitespace collapsed, and is capped at 160 bytes on a UTF-8 rune boundary.
// It returns "none" for a nil err and "unknown" when nothing is left.
func pollFailureFingerprint(err error) string {
	if err == nil {
		return "none"
	}
	msg := strings.ToLower(strings.TrimSpace(err.Error()))
	if idx := strings.Index(msg, "body="); idx >= 0 {
		msg = msg[:idx]
	}
	msg = strings.Join(strings.Fields(msg), " ")

	const maxLen = 160
	if len(msg) > maxLen {
		cut := maxLen
		// Back up over UTF-8 continuation bytes (10xxxxxx) so the fingerprint,
		// which is also logged, stays valid UTF-8.
		for cut > 0 && msg[cut]&0xC0 == 0x80 {
			cut--
		}
		msg = msg[:cut]
	}
	if msg == "" {
		return "unknown"
	}
	return msg
}

// maintenanceBackoffDuration returns an exponential backoff for periodic
// maintenance work. The first failure yields base, and each further failure
// doubles it. For failureCount >= 1 the result never exceeds ceiling. A
// non-positive failureCount returns base unchanged.
func maintenanceBackoffDuration(base, ceiling time.Duration, failureCount int) time.Duration {
	if failureCount <= 0 {
		return base
	}
	delay := base
	for step := 2; step <= failureCount; step++ {
		// Checking against half the ceiling first keeps the doubling from
		// overflowing past it.
		if delay >= ceiling/2 {
			return ceiling
		}
		delay *= 2
		if delay >= ceiling {
			return ceiling
		}
	}
	if delay > ceiling {
		return ceiling
	}
	return delay
}

// maybeEnroll makes sure usable node credentials exist on disk, running
// recovery enrollment with the enrollment token when they do not.
//
// It keeps the existing cert/key pair and returns (false, nil) when the cert
// parses, matches a.cfg.NodeID and has not expired. When an enrollment token
// and a node CA bundle are also configured, the cert must additionally chain
// to that bundle; if the trust check itself errors, the cert is kept. A
// valid cert that no longer chains to the bundle (node CA rotation) is
// re-enrolled. Without GPUAAS_ENROLLMENT_TOKEN, a missing, unreadable,
// mismatched or expired certificate is returned as an error instead. Otherwise
// the result of recoverEnroll is returned.
func (a *agent) maybeEnroll(ctx context.Context) (bool, error) {
	// logVerified logs the current certificate identity once per distinct
	// fingerprint so steady-state maintenance ticks don't repeat it.
	logVerified := func() {
		summary, summaryErr := currentCertSummary(a.cfg.CertPath)
		if summaryErr != nil || strings.TrimSpace(summary.Fingerprint) == "" || summary.Fingerprint == a.lastCertFP {
			return
		}
		a.log.Info("node certificate identity verified", "node_id", a.cfg.NodeID, "common_name", summary.CommonName, "not_after", summary.NotAfter.Format(time.RFC3339), "fingerprint_sha256", summary.Fingerprint)
		a.lastCertFP = summary.Fingerprint
	}

	if fileExists(a.cfg.CertPath) && fileExists(a.cfg.KeyPath) {
		state, err := localCertNodeState(a.cfg.CertPath, a.cfg.NodeID, time.Now().UTC())
		if err == nil && state.matches && !state.expired {
			checkTrust := a.cfg.EnrollmentToken != "" &&
				strings.TrimSpace(a.cfg.NodeCertCABundlePath) != "" &&
				fileExists(a.cfg.NodeCertCABundlePath)
			if !checkTrust {
				logVerified()
				return false, nil
			}
			trusted, trustErr := localCertChainsToBundle(a.cfg.CertPath, a.cfg.NodeCertCABundlePath, time.Now().UTC())
			if trustErr != nil {
				a.log.Warn("local certificate trust check failed; keeping existing certificate", "node_id", a.cfg.NodeID, "cert_path", a.cfg.CertPath, "node_cert_ca_bundle_path", a.cfg.NodeCertCABundlePath, "error", trustErr)
				return false, nil
			}
			if trusted {
				logVerified()
				return false, nil
			}
			// Platform-control CA rotation: bootstrap refreshed the CA bundle
			// but a locally valid cert signed by the old node CA is still on
			// disk. Re-enroll directly; falling through would log a misleading
			// "local certificate expired" for a cert that has not expired.
			a.log.Warn("local certificate does not chain to current node CA bundle; attempting re-enrollment", "node_id", a.cfg.NodeID, "cert_path", a.cfg.CertPath, "node_cert_ca_bundle_path", a.cfg.NodeCertCABundlePath)
			return a.recoverEnroll(ctx, "startup_or_maintenance")
		}

		// From here the cert is unreadable, for another node, or expired.
		if a.cfg.EnrollmentToken == "" {
			switch {
			case err != nil:
				return false, fmt.Errorf("local cert exists but is unreadable for node_id=%s; set GPUAAS_ENROLLMENT_TOKEN or rotate %s/%s: %w", a.cfg.NodeID, a.cfg.CertPath, a.cfg.KeyPath, err)
			case state.matches && state.expired:
				return false, fmt.Errorf("%s: local cert for node_id=%s expired at %s; set GPUAAS_ENROLLMENT_TOKEN for recovery enrollment", certExpiredRecoveryTokenRequiredCode, a.cfg.NodeID, state.notAfter.Format(time.RFC3339))
			default:
				return false, fmt.Errorf("local cert identity mismatch for node_id=%s; set GPUAAS_ENROLLMENT_TOKEN or rotate %s/%s", a.cfg.NodeID, a.cfg.CertPath, a.cfg.KeyPath)
			}
		}
		switch {
		case err != nil:
			a.log.Warn("local certificate invalid; attempting re-enrollment", "node_id", a.cfg.NodeID, "cert_path", a.cfg.CertPath, "error", err)
		case state.matches:
			// Matching but not returned above, so it must be expired.
			a.log.Warn("local certificate expired; attempting re-enrollment", "node_id", a.cfg.NodeID, "cert_path", a.cfg.CertPath, "expired_at", state.notAfter.Format(time.RFC3339))
		default:
			a.log.Warn("local certificate identity mismatch; attempting re-enrollment", "node_id", a.cfg.NodeID, "cert_path", a.cfg.CertPath)
		}
		return a.recoverEnroll(ctx, "startup_or_maintenance")
	}

	if a.cfg.EnrollmentToken == "" {
		return false, errors.New("no cert/key found and GPUAAS_ENROLLMENT_TOKEN is empty")
	}
	return a.recoverEnroll(ctx, "startup_or_maintenance")
}

// tryRecoverRejectedIdentity attempts recovery enrollment after the control
// plane rejected the node identity (cause).
//
// It returns (false, nil) without contacting the server when no enrollment
// token is configured, or when the certificate from a recent recovery is
// still the one on disk (that case is counted and logged). Otherwise it
// returns recoverEnroll's result; reason is passed through for logging.
func (a *agent) tryRecoverRejectedIdentity(ctx context.Context, reason string, cause error) (bool, error) {
	if a.cfg.EnrollmentToken == "" {
		return false, nil
	}
	suppressed, fingerprint, retryIn := a.recentRecoveredCertStillRejected()
	if !suppressed {
		a.log.Warn("node identity rejected by control plane; attempting recovery enrollment", "error", cause, "reason", strings.TrimSpace(reason))
		return a.recoverEnroll(ctx, reason)
	}
	a.metrics.recoveryEnrollmentSuppressedTotal.Add(1)
	a.log.Warn("node identity still rejected after recovery enrollment; backing off before another enrollment attempt",
		"error", cause,
		"fingerprint_sha256", fingerprint,
		"next_retry_after", retryIn.String(),
	)
	return false, nil
}

// recoverEnroll obtains a fresh node certificate by posting a newly generated
// CSR to the recovery endpoint's /internal/v1/nodes/enroll. The request is
// authenticated with the enrollment token over a client that presents no
// client certificate (see enrollmentHTTPClient).
//
// On success it writes the new private key, the certificate and, if the
// response includes one, the node CA bundle. It then records the new
// fingerprint as the latest recovery, drops idle connections on the agent's
// transport, and returns (true, nil). reason is only used for request logging.
// Attempt/success/failure metrics are updated by the deferred hook; the
// empty-token early return happens before that hook and is not counted.
func (a *agent) recoverEnroll(ctx context.Context, reason string) (enrolled bool, err error) {
	if a.cfg.EnrollmentToken == "" {
		return false, fmt.Errorf("GPUAAS_ENROLLMENT_TOKEN is empty")
	}
	a.metrics.recoveryEnrollmentAttemptsTotal.Add(1)
	defer func() {
		if err != nil {
			a.metrics.recoveryEnrollmentFailuresTotal.Add(1)
			return
		}
		if enrolled {
			a.metrics.recoveryEnrollmentSuccessTotal.Add(1)
		}
	}()

	priv, csrPEM, err := generateNodeCSR(a.cfg.NodeID)
	if err != nil {
		return false, fmt.Errorf("generate csr: %w", err)
	}
	body, err := json.Marshal(map[string]string{
		"node_id": a.cfg.NodeID,
		"csr":     csrPEM,
	})
	if err != nil {
		return false, fmt.Errorf("encode enroll request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.recoveryAPIURL()+"/internal/v1/nodes/enroll", bytes.NewReader(body))
	if err != nil {
		return false, err
	}
	req.Header.Set("Content-Type", "application/json")
	a.setInternalAuthHeaders(req, a.cfg.EnrollmentToken)
	a.logInternalRequest("node recovery enrollment request", req, "reason", strings.TrimSpace(reason))

	enrollClient := a.enrollmentHTTPClient()
	// The client owns a private transport; release its connections on every
	// path, including when Do fails.
	defer enrollClient.CloseIdleConnections()
	res, err := enrollClient.Do(req)
	if err != nil {
		return false, err
	}
	defer func() { _ = res.Body.Close() }()
	if res.StatusCode >= 400 {
		raw, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
		return false, fmt.Errorf("enroll failed: status=%d body=%s", res.StatusCode, string(raw))
	}
	var enrollResp struct {
		Certificate string `json:"certificate"`
		CABundle    string `json:"ca_bundle"`
	}
	if err := json.NewDecoder(res.Body).Decode(&enrollResp); err != nil {
		return false, fmt.Errorf("decode enroll response: %w", err)
	}
	if enrollResp.Certificate == "" {
		return false, fmt.Errorf("enroll response missing certificate")
	}

	keyPEM, err := privateKeyToPEM(priv)
	if err != nil {
		return false, fmt.Errorf("encode private key: %w", err)
	}
	certPEM := enrollResp.Certificate
	// Refuse to overwrite the on-disk credentials with a certificate the agent
	// could not use: the client's GetClientCertificate callback loads the pair
	// with tls.LoadX509KeyPair, which parses exactly like tls.X509KeyPair, so a
	// pair rejected here would leave the node unable to authenticate.
	if _, pairErr := tls.X509KeyPair([]byte(certPEM), []byte(keyPEM)); pairErr != nil {
		return false, fmt.Errorf("enroll response certificate unusable with generated key: %w", pairErr)
	}
	// NOTE(review): key and cert are still written as two separate files; a
	// failure between the writes leaves a mismatched pair on disk.
	if err := writeSecureFile(a.cfg.KeyPath, keyPEM, 0o600); err != nil {
		return false, fmt.Errorf("write private key: %w", err)
	}
	if err := writeSecureFile(a.cfg.CertPath, certPEM, 0o644); err != nil {
		return false, fmt.Errorf("write certificate: %w", err)
	}
	if enrollResp.CABundle != "" {
		if err := writeSecureFile(a.cfg.NodeCertCABundlePath, enrollResp.CABundle, 0o644); err != nil {
			return false, fmt.Errorf("write ca bundle: %w", err)
		}
	}
	if summary, summaryErr := currentCertSummary(a.cfg.CertPath); summaryErr == nil {
		a.log.Info("node certificate enrolled", "node_id", a.cfg.NodeID, "common_name", summary.CommonName, "not_after", summary.NotAfter.Format(time.RFC3339), "fingerprint_sha256", summary.Fingerprint)
		a.lastCertFP = summary.Fingerprint
		a.lastRecoveryFP = summary.Fingerprint
		a.lastRecoveryAt = time.Now().UTC()
	}
	a.closeIdleConnections()
	return true, nil
}

// recentRecoveredCertStillRejected reports whether another recovery enrollment
// should be suppressed. That is the case when a recovery happened less than
// credentialTaskPollBackoff ago and the certificate on disk is still the one
// that recovery produced. When suppressed it also returns that fingerprint and
// the time remaining in the window; otherwise it returns (false, "", 0).
func (a *agent) recentRecoveredCertStillRejected() (bool, string, time.Duration) {
	if strings.TrimSpace(a.lastRecoveryFP) == "" || a.lastRecoveryAt.IsZero() {
		return false, "", 0
	}
	since := time.Since(a.lastRecoveryAt)
	if inWindow := since >= 0 && since < credentialTaskPollBackoff; !inWindow {
		return false, "", 0
	}
	summary, err := currentCertSummary(a.cfg.CertPath)
	switch {
	case err != nil,
		strings.TrimSpace(summary.Fingerprint) == "",
		summary.Fingerprint != a.lastRecoveryFP:
		return false, "", 0
	}
	return true, summary.Fingerprint, credentialTaskPollBackoff - since
}

// loadRootCAs returns the system trust store (or an empty pool if it is
// unavailable) extended with the PEM certificates in caBundlePath.
//
// An empty or non-existent caBundlePath is not an error: the base pool is
// returned as-is. A bundle that cannot be read, or that contains no parseable
// certificates, yields an error and a nil pool.
func loadRootCAs(caBundlePath string) (*x509.CertPool, error) {
	pool, sysErr := x509.SystemCertPool()
	if sysErr != nil || pool == nil {
		pool = x509.NewCertPool()
	}
	if strings.TrimSpace(caBundlePath) == "" || !fileExists(caBundlePath) {
		return pool, nil
	}
	bundle, err := os.ReadFile(caBundlePath)
	switch {
	case err != nil:
		return nil, fmt.Errorf("read ca bundle: %w", err)
	case !pool.AppendCertsFromPEM(bundle):
		return nil, fmt.Errorf("append ca bundle: no certificates parsed")
	default:
		return pool, nil
	}
}

// pollOnce performs one task-wait round trip against the control plane and,
// when a task is returned, validates it, executes it and reports the result.
//
// Return values:
//   - nil when the wait endpoint answers 204 No Content, or after a claimed
//     task's result (including a "rejected" envelope) has been posted.
//   - an error wrapping ErrLocalCertExpired when the local certificate names
//     this node but is past its NotAfter; no request is sent in that case.
//   - an error wrapping ErrNodeIdentityRevoked when a non-200 wait response is
//     recognized by isNodeIdentityRevokedResponse.
//   - other errors for request construction/transport failures, unexpected
//     status codes, an undecodable task body, an envelope with no task_id, or
//     a failed result post.
func (a *agent) pollOnce(ctx context.Context) error {
	// Errors reading/parsing the local certificate are ignored here; only a
	// certificate that matches this node and has expired short-circuits the poll.
	state, err := localCertNodeState(a.cfg.CertPath, a.cfg.NodeID, time.Now().UTC())
	if err == nil && state.matches && state.expired {
		return fmt.Errorf("%w at %s; recovery enrollment required before task polling", ErrLocalCertExpired, state.notAfter.UTC().Format(time.RFC3339))
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.taskWaitURL(), nil)
	if err != nil {
		return err
	}
	a.setInternalAuthHeaders(req, a.cfg.NodeAPIToken)
	res, err := a.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer func() { _ = res.Body.Close() }()

	// 204: nothing to claim. Returned before the debug log below, so empty
	// polls are not logged.
	if res.StatusCode == http.StatusNoContent {
		return nil
	}
	a.logInternalRequestDebug("node task wait response", req, "status_code", res.StatusCode)
	if res.StatusCode != http.StatusOK {
		// Error bodies are capped at 4 KiB before being inspected and echoed.
		raw, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
		if isNodeIdentityRevokedResponse(res.StatusCode, raw) {
			return fmt.Errorf("%w: task wait rejected by control plane", ErrNodeIdentityRevoked)
		}
		return fmt.Errorf("task wait failed: status=%d body=%s", res.StatusCode, string(raw))
	}

	var task nodeTask
	if err := json.NewDecoder(res.Body).Decode(&task); err != nil {
		return fmt.Errorf("decode task: %w", err)
	}
	// The claim is counted and logged before the envelope is validated, so
	// rejected envelopes are included in taskClaimsTotal.
	a.metrics.taskClaimsTotal.Add(1)
	a.log.Info("node task claimed",
		"task_id", task.TaskID,
		"task_type", task.TaskType,
		"node_id", task.NodeID,
		"correlation_id", task.CorrelationID,
		"expires_at", task.ExpiresAt,
		"task_params", summarizeTaskParamsForLog(task),
	)

	if err := a.validateTaskEnvelope(task); err != nil {
		a.log.Warn("node task envelope rejected",
			"task_id", task.TaskID,
			"task_type", task.TaskType,
			"node_id", task.NodeID,
			"correlation_id", task.CorrelationID,
			"error", err,
		)
		// Without a task_id there is nothing to report the rejection against.
		if strings.TrimSpace(task.TaskID) == "" {
			return fmt.Errorf("task envelope invalid: %w", err)
		}
		// Rejections are posted once via postTaskResult (no retry loop), under ctx.
		return a.postTaskResult(ctx, task.TaskID, nodeTaskResult{
			Status:      "rejected",
			Error:       fmt.Sprintf("task_envelope_invalid: %v", err),
			CompletedAt: time.Now().UTC().Format(time.RFC3339),
			Output:      map[string]any{},
		})
	}

	// Execution is bounded by the task's expires_at when it parses as RFC 3339.
	taskCtx, cancelTask := taskExecutionContext(ctx, task)
	defer cancelTask()
	dispatched := dispatchTask(taskCtx, task)
	// CompletedAt is stamped after dispatchTask returns.
	result := nodeTaskResult{
		CompletedAt: time.Now().UTC().Format(time.RFC3339),
	}
	result.Status = dispatched.Status
	result.Error = dispatched.Error
	result.Output = dispatched.Output
	// Normalize a nil output to an empty map before logging and posting.
	if result.Output == nil {
		result.Output = map[string]any{}
	}
	switch result.Status {
	case "success":
		a.log.Info("node task completed",
			"task_id", task.TaskID,
			"task_type", task.TaskType,
			"node_id", task.NodeID,
			"correlation_id", task.CorrelationID,
			"task_output", summarizeTaskOutputForLog(task.TaskType, result.Output),
		)
	case "failed", "rejected":
		a.log.Warn("node task failed",
			"task_id", task.TaskID,
			"task_type", task.TaskType,
			"node_id", task.NodeID,
			"correlation_id", task.CorrelationID,
			"error", result.Error,
			"task_output", summarizeTaskOutputForLog(task.TaskType, result.Output),
		)
	default:
		a.log.Info("node task finished",
			"task_id", task.TaskID,
			"task_type", task.TaskType,
			"node_id", task.NodeID,
			"correlation_id", task.CorrelationID,
			"status", result.Status,
			"task_output", summarizeTaskOutputForLog(task.TaskType, result.Output),
		)
	}
	// The result post is detached from ctx cancellation and retried
	// (see postTaskResultWithRetry).
	return a.postTaskResultWithRetry(ctx, task.TaskID, result)
}

// taskExecutionContext derives the context a task runs under.
//
// When the task's expires_at (surrounding whitespace ignored) parses as
// RFC 3339, the returned context carries that deadline. Otherwise it is only
// cancellable. Callers must always invoke the returned CancelFunc.
func taskExecutionContext(ctx context.Context, task nodeTask) (context.Context, context.CancelFunc) {
	if deadline, parseErr := time.Parse(time.RFC3339, strings.TrimSpace(task.ExpiresAt)); parseErr == nil {
		return context.WithDeadline(ctx, deadline)
	}
	return context.WithCancel(ctx)
}

// maybeRenewCert renews the node's mTLS client certificate once it is within
// cfg.RenewBefore of expiry.
//
// It is a no-op when no certificate file exists, or when the current leaf
// certificate still has more than RenewBefore of validity left. An expired
// certificate cannot be renewed over mTLS, so ErrLocalCertExpired is returned
// and the caller must recover through enrollment. If the control plane
// rejects the node identity, the error wraps ErrNodeIdentityRevoked.
//
// On success the new private key and certificate replace the files at
// cfg.KeyPath and cfg.CertPath. Idle connections are then closed so later
// requests handshake with the renewed credentials. The returned certificate
// is checked against the freshly generated key BEFORE anything is written.
// A malformed or mismatched response therefore leaves the existing (still
// valid) credentials in place instead of replacing them with an unusable pair.
func (a *agent) maybeRenewCert(ctx context.Context) error {
	if !fileExists(a.cfg.CertPath) {
		return nil
	}
	certs, err := loadCertChain(a.cfg.CertPath)
	if err != nil {
		return err
	}
	if len(certs) == 0 {
		return fmt.Errorf("no certificate found in %s", a.cfg.CertPath)
	}
	cert := selectLeafCert(certs)
	if time.Now().UTC().After(cert.NotAfter) {
		return fmt.Errorf("%w at %s; renewal requires valid mTLS cert and must recover via enrollment token", ErrLocalCertExpired, cert.NotAfter.UTC().Format(time.RFC3339))
	}
	oldNotAfter := cert.NotAfter
	if time.Until(cert.NotAfter) > a.cfg.RenewBefore {
		return nil
	}

	priv, csrPEM, err := generateNodeCSR(a.cfg.NodeID)
	if err != nil {
		return fmt.Errorf("generate renew csr: %w", err)
	}

	// Marshalling a map[string]string cannot fail.
	reqBody, _ := json.Marshal(map[string]string{"csr": csrPEM})
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, a.cfg.APIURL+"/internal/v1/nodes/"+a.cfg.NodeID+"/cert/renew", bytes.NewReader(reqBody))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	a.setInternalAuthHeaders(req, a.cfg.NodeAPIToken)
	a.logInternalRequest("node certificate renew request", req)

	res, err := a.httpClient.Do(req)
	if err != nil {
		return err
	}
	defer func() { _ = res.Body.Close() }()
	if res.StatusCode >= 400 {
		raw, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
		if isNodeIdentityRevokedResponse(res.StatusCode, raw) {
			return fmt.Errorf("%w: renew rejected by control plane", ErrNodeIdentityRevoked)
		}
		return fmt.Errorf("renew failed: status=%d body=%s", res.StatusCode, string(raw))
	}
	var renewResp struct {
		Certificate string `json:"certificate"`
	}
	if err := json.NewDecoder(res.Body).Decode(&renewResp); err != nil {
		return fmt.Errorf("decode renew response: %w", err)
	}
	if strings.TrimSpace(renewResp.Certificate) == "" {
		return fmt.Errorf("renew response missing certificate")
	}

	keyPEM, err := privateKeyToPEM(priv)
	if err != nil {
		return fmt.Errorf("encode renew private key: %w", err)
	}
	// tls.X509KeyPair parses the certificate PEM and confirms the leaf's public
	// key matches the new private key. Doing this before touching disk keeps
	// the current working key/cert pair if the response is unusable.
	if _, err := tls.X509KeyPair([]byte(renewResp.Certificate), []byte(keyPEM)); err != nil {
		return fmt.Errorf("renewed certificate does not match generated key: %w", err)
	}
	if err := writeSecureFile(a.cfg.KeyPath, keyPEM, 0o600); err != nil {
		return fmt.Errorf("write renewed private key: %w", err)
	}
	if err := writeSecureFile(a.cfg.CertPath, renewResp.Certificate, 0o644); err != nil {
		return fmt.Errorf("write renewed cert: %w", err)
	}
	if summary, summaryErr := currentCertSummary(a.cfg.CertPath); summaryErr == nil {
		a.log.Info("node certificate renewed", "node_id", a.cfg.NodeID, "old_not_after", oldNotAfter.Format(time.RFC3339), "new_not_after", summary.NotAfter.Format(time.RFC3339), "fingerprint_sha256", summary.Fingerprint)
		a.lastCertFP = summary.Fingerprint
	}
	a.closeIdleConnections()
	return nil
}

// postTaskResult makes a single attempt to report result for taskID to the
// control plane.
//
// Any response other than 200 OK is an error carrying the status code and up
// to 4 KiB of the response body. Transport and non-200 failures increment
// taskResultPostFailures. On success, the completed/failed/rejected counter
// matching result.Status (compared case- and whitespace-insensitively) is
// incremented.
func (a *agent) postTaskResult(ctx context.Context, taskID string, result nodeTaskResult) error {
	payload, err := json.Marshal(result)
	if err != nil {
		return err
	}
	endpoint := a.taskResultURL(taskID)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	a.setInternalAuthHeaders(req, a.cfg.NodeAPIToken)

	trimmedStatus := strings.TrimSpace(result.Status)
	a.logInternalRequest("node task result post", req,
		"task_id", taskID,
		"task_status", trimmedStatus,
	)

	res, err := a.httpClient.Do(req)
	if err != nil {
		a.metrics.taskResultPostFailures.Add(1)
		return err
	}
	defer func() { _ = res.Body.Close() }()

	a.logInternalRequest("node task result response", req,
		"task_id", taskID,
		"task_status", trimmedStatus,
		"status_code", res.StatusCode,
	)
	if res.StatusCode != http.StatusOK {
		a.metrics.taskResultPostFailures.Add(1)
		body, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
		return fmt.Errorf("task result failed: status=%d body=%s", res.StatusCode, string(body))
	}

	normalized := strings.TrimSpace(strings.ToLower(result.Status))
	if normalized == "success" {
		a.metrics.taskCompletedTotal.Add(1)
	} else if normalized == "failed" {
		a.metrics.taskFailedTotal.Add(1)
	} else if normalized == "rejected" {
		a.metrics.taskRejectedTotal.Add(1)
	}
	return nil
}

// postTaskResultWithRetry reports a task result to the control plane, retrying
// failed posts with capped exponential backoff.
//
// The posts run under a context detached from ctx's cancellation
// (context.WithoutCancel) but bounded to 45s overall. This way the result of
// work that already ran is still delivered when the caller's context is
// shutting down. At most 6 attempts are made. Waits between attempts start at
// 1s and double up to a 15s cap. No wait follows the final attempt, so
// exhausting the attempts returns promptly with the last error instead of
// sleeping into the overall timeout.
func (a *agent) postTaskResultWithRetry(ctx context.Context, taskID string, result nodeTaskResult) error {
	const (
		maxAttempts    = 6
		initialBackoff = time.Second
		maxBackoff     = 15 * time.Second
		overallTimeout = 45 * time.Second
	)

	resultCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), overallTimeout)
	defer cancel()

	backoff := initialBackoff
	var lastErr error
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		err := a.postTaskResult(resultCtx, taskID, result)
		if err == nil {
			if attempt > 1 {
				a.log.Info("node task result post recovered",
					"task_id", taskID,
					"attempt", attempt,
					"task_status", strings.TrimSpace(result.Status),
				)
			}
			return nil
		}
		lastErr = err
		// Stop when the overall budget is spent or there are no attempts left;
		// waiting after the final attempt would only delay the error.
		if resultCtx.Err() != nil || attempt == maxAttempts {
			break
		}
		a.log.Warn("node task result post failed; retrying",
			"task_id", taskID,
			"attempt", attempt,
			"error", err,
			"next_retry_after", backoff.String(),
		)
		timer := time.NewTimer(backoff)
		select {
		case <-resultCtx.Done():
			timer.Stop()
			return fmt.Errorf("post task result retry timeout: %w", lastErr)
		case <-timer.C:
		}
		backoff = min(backoff*2, maxBackoff)
	}
	return fmt.Errorf("post task result failed after retries: %w", lastErr)
}

// taskWaitURL returns the control-plane endpoint this node polls to claim its
// next task.
func (a *agent) taskWaitURL() string {
	const waitPathFormat = "%s/internal/v1/nodes/%s/tasks/wait"
	return fmt.Sprintf(waitPathFormat, a.cfg.APIURL, a.cfg.NodeID)
}

// taskResultURL returns the control-plane endpoint used to report the outcome
// of taskID.
//
// Both path segments are escaped with url.PathEscape. taskID comes from the
// control plane's task-wait response, and an unescaped '/', '?' or '#' in it
// would otherwise send the POST to a different route. IDs made of unreserved
// characters (letters, digits, '-', '.', '_', '~') are unchanged.
func (a *agent) taskResultURL(taskID string) string {
	return fmt.Sprintf("%s/internal/v1/nodes/%s/tasks/%s/result",
		a.cfg.APIURL, url.PathEscape(a.cfg.NodeID), url.PathEscape(taskID))
}

// setInternalAuthHeaders decorates an outbound control-plane request.
//
// It adds a Bearer Authorization header when bearerToken is non-blank, plus
// the node instance ID, the agent build metadata (version, commit, build
// time) and the agent User-Agent string.
func (a *agent) setInternalAuthHeaders(req *http.Request, bearerToken string) {
	headers := req.Header
	if strings.TrimSpace(bearerToken) != "" {
		headers.Set("Authorization", "Bearer "+bearerToken)
	}
	identity := [...][2]string{
		{"X-GPUAAS-Node-Instance-ID", a.cfg.InstanceID},
		{"X-GPUAAS-Agent-Version", buildinfo.Version},
		{"X-GPUAAS-Agent-Commit", buildinfo.Commit},
		{"X-GPUAAS-Agent-Built-At", buildinfo.BuiltAt},
	}
	for _, pair := range identity {
		headers.Set(pair[0], strings.TrimSpace(pair[1]))
	}
	headers.Set("User-Agent", a.internalUserAgent())
}

// internalUserAgent builds the User-Agent sent to the control plane, in the
// form "gpuaas-node-agent/<version> pid=<pid> instance=<instance id>".
func (a *agent) internalUserAgent() string {
	segments := []string{
		"gpuaas-node-agent/" + strings.TrimSpace(buildinfo.Version),
		"pid=" + strconv.Itoa(os.Getpid()),
		"instance=" + strings.TrimSpace(a.cfg.InstanceID),
	}
	return strings.Join(segments, " ")
}

// logInternalRequest logs an outbound control-plane request at Info level,
// with the standard node/request fields followed by extra key/value pairs.
func (a *agent) logInternalRequest(message string, req *http.Request, extra ...any) {
	const level = slog.LevelInfo
	a.logInternalRequestAt(level, message, req, extra...)
}

// logInternalRequestDebug is the Debug-level variant of logInternalRequest,
// used for high-frequency traffic such as task-wait responses.
func (a *agent) logInternalRequestDebug(message string, req *http.Request, extra ...any) {
	const level = slog.LevelDebug
	a.logInternalRequestAt(level, message, req, extra...)
}

// logInternalRequestAt emits message at level with the fields common to every
// control-plane request: node ID, trimmed instance ID, process ID, HTTP method,
// URL path and API base URL. The caller's extra key/value pairs follow.
func (a *agent) logInternalRequestAt(level slog.Level, message string, req *http.Request, extra ...any) {
	attrs := make([]any, 0, 12+len(extra))
	attrs = append(attrs,
		"node_id", a.cfg.NodeID,
		"node_instance_id", strings.TrimSpace(a.cfg.InstanceID),
		"pid", os.Getpid(),
		"method", req.Method,
		"path", req.URL.Path,
		"api_url", a.cfg.APIURL,
	)
	attrs = append(attrs, extra...)
	a.log.Log(context.Background(), level, message, attrs...)
}

// closeIdleConnections drops pooled keep-alive connections on the agent's
// transport, if one is configured, so that the next request opens a new
// connection (for example after the client certificate changed on disk).
func (a *agent) closeIdleConnections() {
	if t := a.transport; t != nil {
		t.CloseIdleConnections()
	}
}

// summarizeTaskParamsForLog builds a compact, log-friendly view of a task's
// params, choosing fields by task type.
//
// Bulky or credential-like inputs are never copied verbatim. SSH keys and
// slots are reported as counts, compose YAML as a byte length, and pull
// credential / image artifact references as presence booleans. Missing or
// mistyped string params appear as "", and missing int/bool params as nil.
// Unknown task types yield an empty map.
func summarizeTaskParamsForLog(task nodeTask) map[string]any {
	params := task.Params
	switch strings.TrimSpace(task.TaskType) {
	case "allocation.provision_user":
		return map[string]any{
			"username_on_node": usernameOnNodeFromParams(params),
			"uid":              logIntParam(params, "uid"),
			"gid":              logIntParam(params, "gid"),
			"ssh_key_count":    logSliceLen(params, "ssh_public_keys"),
		}
	case "allocation.install_authorized_keys":
		return map[string]any{
			"allocation_id":    logStringParam(params, "allocation_id"),
			"username_on_node": usernameOnNodeFromParams(params),
			"ssh_key_count":    logSliceLen(params, "ssh_public_keys"),
		}
	case "allocation.revoke_user":
		return map[string]any{
			"username_on_node": usernameOnNodeFromParams(params),
			"reason":           logStringParam(params, "reason"),
		}
	case "runtime.write_env_file":
		// Only the target path is logged; the env contents are not.
		return map[string]any{
			"env_file_path": logStringParam(params, "env_file_path"),
		}
	case "runtime.install_service_unit":
		return map[string]any{
			"systemd_unit_name": logStringParam(params, "systemd_unit_name"),
		}
	case "runtime.service_control":
		return map[string]any{
			"systemd_unit_name": logStringParam(params, "systemd_unit_name"),
			"action":            logStringParam(params, "action"),
		}
	case "terminal.open":
		return map[string]any{
			"allocation_id":  logStringParam(params, "allocation_id"),
			"username":       logStringParam(params, "username"),
			"capacity_shape": logStringParam(params, "capacity_shape"),
			"target_host":    logStringParam(params, "target_host"),
			"target_port":    logIntParam(params, "target_port"),
			"cols":           logIntParam(params, "cols"),
			"rows":           logIntParam(params, "rows"),
		}
	case "terminal.close":
		return map[string]any{
			"allocation_id": logStringParam(params, "allocation_id"),
			"reason":        logStringParam(params, "reason"),
		}
	case "node.uninstall":
		return map[string]any{
			"install_root":      logStringParam(params, "install_root"),
			"systemd_unit_name": logStringParam(params, "systemd_unit_name"),
			"reason":            logStringParam(params, "reason"),
		}
	case "node.self_update":
		return map[string]any{
			"expected_version":          logStringParam(params, "expected_version"),
			"package_sha256":            logStringParam(params, "package_sha256"),
			"install_root":              logStringParam(params, "install_root"),
			"systemd_unit_name":         logStringParam(params, "systemd_unit_name"),
			"install_container_runtime": logBoolParam(params, "install_container_runtime"),
			"container_runtime_package": logStringParam(params, "container_runtime_package"),
			"reason":                    logStringParam(params, "reason"),
		}
	case taskTypeWorkloadOCIRuntimeStatus:
		return map[string]any{}
	case taskTypeWorkloadOCILaunch:
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"allocation_id":        logStringParam(params, "allocation_id"),
			"username":             logStringParam(params, "username"),
			"container_name":       logStringParam(params, "container_name"),
			"has_pull_credential":  logStringParam(params, "pull_credential_ref") != "",
		}
	case taskTypeWorkloadOCIControl:
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"action":               logStringParam(params, "action"),
		}
	case taskTypeWorkloadOCIRemove:
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"remove_scratch":       logBoolParam(params, "remove_scratch"),
		}
	case taskTypeWorkloadComposeLaunch:
		// The compose document itself is summarized by its trimmed byte length.
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"allocation_id":        logStringParam(params, "allocation_id"),
			"username":             logStringParam(params, "username"),
			"project_name":         logStringParam(params, "project_name"),
			"compose_yaml_bytes":   len(logStringParam(params, "compose_yaml")),
		}
	case taskTypeWorkloadComposeControl:
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"action":               logStringParam(params, "action"),
		}
	case taskTypeWorkloadComposeRemove:
		return map[string]any{
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"remove_scratch":       logBoolParam(params, "remove_scratch"),
		}
	case taskTypeSliceVMProvision:
		return map[string]any{
			"allocation_id":  logStringParam(params, "allocation_id"),
			"vm_name":        logStringParam(params, "vm_name"),
			"image_sha256":   logStringParam(params, "image_sha256"),
			"has_artifact":   logStringParam(params, "image_artifact_ref") != "",
			"default_user":   logStringParam(params, "default_username"),
			"ssh_key_count":  logSliceLen(params, "ssh_public_keys"),
			"slot_count":     logSliceLen(params, "slots"),
			"ovs_bridge":     logStringParam(params, "ovs_bridge"),
			"has_image_path": logStringParam(params, "image_path") != "",
		}
	case taskTypeSliceVMRelease:
		return map[string]any{
			"allocation_id": logStringParam(params, "allocation_id"),
			"vm_name":       logStringParam(params, "vm_name"),
			"slot_count":    logSliceLen(params, "slots"),
			"wipe":          logBoolParam(params, "wipe"),
		}
	case taskTypeStorageMountAttach:
		return map[string]any{
			"attachment_id":        logStringParam(params, "attachment_id"),
			"allocation_id":        logStringParam(params, "allocation_id"),
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"provider_backend":     logStringParam(params, "provider_backend"),
			"mount_path":           logStringParam(params, "mount_path"),
			"access_mode":          logStringParam(params, "access_mode"),
			"write_policy":         logStringParam(params, "write_policy"),
		}
	case taskTypeStorageMountDetach:
		return map[string]any{
			"attachment_id":        logStringParam(params, "attachment_id"),
			"allocation_id":        logStringParam(params, "allocation_id"),
			"workload_instance_id": logStringParam(params, "workload_instance_id"),
			"mount_path":           logStringParam(params, "mount_path"),
			"force":                logBoolParam(params, "force"),
		}
	case taskTypeStorageMountProbe:
		return map[string]any{
			"attachment_id": logStringParam(params, "attachment_id"),
			"mount_path":    logStringParam(params, "mount_path"),
			"access_mode":   logStringParam(params, "access_mode"),
		}
	default:
		return map[string]any{}
	}
}

// summarizeTaskOutputForLog selects a small, log-friendly subset of a task's
// output map for the task types listed below. Keys absent from output appear
// in the summary with a nil value.
//
// For any other task type the output map itself is returned unchanged.
// NOTE(review): that default branch logs the full output; confirm that no
// task type routed there can return secrets or large payloads.
func summarizeTaskOutputForLog(taskType string, output map[string]any) map[string]any {
	switch strings.TrimSpace(taskType) {
	case "allocation.provision_user":
		return map[string]any{
			"applied":       output["applied"],
			"verified":      output["verified"],
			"username":      output["username"],
			"uid":           output["uid"],
			"gid":           output["gid"],
			"ssh_key_count": output["ssh_key_count"],
		}
	case "allocation.install_authorized_keys":
		return map[string]any{
			"applied":       output["applied"],
			"verified":      output["verified"],
			"username":      output["username"],
			"ssh_key_count": output["ssh_key_count"],
		}
	case "allocation.revoke_user":
		return map[string]any{
			"username":            output["username"],
			"user_existed_before": output["user_existed_before"],
			"home_existed_before": output["home_existed_before"],
			"user_exists_after":   output["user_exists_after"],
			"home_exists_after":   output["home_exists_after"],
		}
	case "terminal.open":
		return map[string]any{
			"accepted":      output["accepted"],
			"session_id":    output["session_id"],
			"allocation_id": output["allocation_id"],
			"username":      output["username"],
		}
	case taskTypeSliceVMProvision:
		return map[string]any{
			"vm_name":       output["vm_name"],
			"default_user":  output["default_user"],
			"private_ip":    output["private_ip"],
			"ssh_port":      output["ssh_port"],
			"slot_count":    output["slot_count"],
			"readiness":     output["readiness"],
			"raw_vnc":       output["raw_vnc"],
			"console_model": output["console_model"],
		}
	case taskTypeSliceVMRelease:
		return map[string]any{
			"vm_name":      output["vm_name"],
			"released":     output["released"],
			"hard_stopped": output["hard_stopped"],
			"wiped":        output["wiped"],
			"slot_count":   output["slot_count"],
		}
	case taskTypeStorageMountAttach, taskTypeStorageMountDetach, taskTypeStorageMountProbe:
		// One shared shape for all storage-mount tasks; keys a given task does
		// not produce are simply nil.
		return map[string]any{
			"attachment_id": output["attachment_id"],
			"mount_path":    output["mount_path"],
			"provider":      output["provider_backend"],
			"access_mode":   output["access_mode"],
			"mounted":       output["mounted"],
			"detached":      output["detached"],
			"present":       output["present"],
		}
	default:
		return output
	}
}

// logStringParam returns params[key] with surrounding whitespace removed, or
// "" when the key is missing, the map is nil, or the value is not a string.
func logStringParam(params map[string]any, key string) string {
	if s, isString := params[key].(string); isString {
		return strings.TrimSpace(s)
	}
	return ""
}

// logIntParam returns the integer parsed by intParam for key, or nil when
// intParam reports an error, so that log output shows an absent value rather
// than a misleading zero.
func logIntParam(params map[string]any, key string) any {
	if parsed, err := intParam(params, key); err == nil {
		return parsed
	}
	return nil
}

// logSliceLen reports how many elements params[key] holds, for logging.
//
// It uses stringSliceParam's result when that succeeds. Otherwise it falls
// back to the length of a raw []any value, and returns 0 for anything else
// (missing key, nil, other types).
func logSliceLen(params map[string]any, key string) int {
	if values, err := stringSliceParam(params, key); err == nil {
		return len(values)
	}
	switch raw := params[key].(type) {
	case []any:
		return len(raw)
	default:
		return 0
	}
}

// usernameOnNodeFromParams returns the trimmed "username_on_node" param,
// falling back to the trimmed "username" param, or "" if neither is set.
func usernameOnNodeFromParams(params map[string]any) string {
	for _, candidate := range []string{"username_on_node", "username"} {
		if name := logStringParam(params, candidate); name != "" {
			return name
		}
	}
	return ""
}

// fileExists reports whether os.Stat succeeds for path. It therefore returns
// true for directories as well as files, and false for any stat error
// (including permission errors), not only "does not exist".
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}

// isNodeIdentityRevokedResponse classifies a control-plane error response as
// "this node's identity is no longer accepted".
//
// It matches a 401 whose body mentions "invalid node identity", or a 404
// whose body mentions "node_not_found" / "node not found". Matching is
// case-insensitive. Every other status returns false.
func isNodeIdentityRevokedResponse(statusCode int, rawBody []byte) bool {
	lowered := strings.ToLower(string(rawBody))
	if statusCode == http.StatusUnauthorized {
		return strings.Contains(lowered, "invalid node identity")
	}
	if statusCode == http.StatusNotFound {
		for _, marker := range []string{"node_not_found", "node not found"} {
			if strings.Contains(lowered, marker) {
				return true
			}
		}
	}
	return false
}

// validateTaskEnvelope checks a claimed task before it is executed.
//
// Checks run in order and the first failure is returned:
//   - required identifiers present, and the task is addressed to this node;
//   - expires_at present, RFC 3339, and not in the past;
//   - a signature present and valid over the JSON-encoded params
//     (verifyNodeTaskEnvelopeSignature).
func (a *agent) validateTaskEnvelope(task nodeTask) error {
	switch {
	case strings.TrimSpace(task.TaskID) == "":
		return fmt.Errorf("missing task_id")
	case strings.TrimSpace(task.TaskType) == "":
		return fmt.Errorf("missing task_type")
	case strings.TrimSpace(task.NodeID) == "":
		return fmt.Errorf("missing node_id")
	case task.NodeID != a.cfg.NodeID:
		return fmt.Errorf("node_id mismatch")
	case strings.TrimSpace(task.ExpiresAt) == "":
		return fmt.Errorf("missing expires_at")
	}

	expiresAt, parseErr := time.Parse(time.RFC3339, task.ExpiresAt)
	if parseErr != nil {
		return fmt.Errorf("invalid expires_at")
	}
	if time.Now().UTC().After(expiresAt) {
		return fmt.Errorf("task expired")
	}

	if strings.TrimSpace(task.Signature) == "" {
		return fmt.Errorf("missing signature")
	}
	paramsJSON, marshalErr := json.Marshal(task.Params)
	if marshalErr != nil {
		return fmt.Errorf("invalid params")
	}
	return verifyNodeTaskEnvelopeSignature(a.cfg, task, paramsJSON)
}

// logBoolParam returns params[key] when it holds a bool, and nil otherwise
// (missing key, nil map, or a non-bool value such as the string "true").
func logBoolParam(params map[string]any, key string) any {
	if flag, isBool := params[key].(bool); isBool {
		return flag
	}
	return nil
}

// verifyNodeTaskEnvelopeSignature checks task.Signature over the message built
// by nodeTaskSignatureMessage from the task type, node ID, task ID and the
// JSON-encoded params.
//
// Two schemes are accepted:
//   - "ed25519:<base64url, unpadded>" verified against any valid key listed
//     in cfg.TaskSigningPubKeys;
//   - otherwise HMAC-SHA256 (optional "hmac:" prefix, base64url, unpadded)
//     keyed with cfg.NodeTaskSignKey.
//
// When cfg.NodeTaskSignKey is empty, the HMAC key falls back to the literal
// "unsigned". That key is public, so the fallback is honoured only when no
// ed25519 task signing keys are configured. Otherwise a forged HMAC signature
// would be an easy way around asymmetric verification (a downgrade).
func verifyNodeTaskEnvelopeSignature(cfg Config, task nodeTask, params []byte) error {
	message := nodeTaskSignatureMessage(task.TaskType, task.NodeID, task.TaskID, params)

	switch {
	case strings.HasPrefix(task.Signature, "ed25519:"):
		rawSig := strings.TrimPrefix(task.Signature, "ed25519:")
		sigBytes, err := base64.RawURLEncoding.DecodeString(rawSig)
		if err != nil || len(sigBytes) != ed25519.SignatureSize {
			return fmt.Errorf("invalid ed25519 signature encoding")
		}
		pubKeys := taskSigningPublicKeys(cfg.TaskSigningPubKeys)
		if len(pubKeys) == 0 {
			return fmt.Errorf("missing or invalid task signing public key")
		}
		for _, pubKey := range pubKeys {
			if ed25519.Verify(pubKey, message, sigBytes) {
				return nil
			}
		}
		return fmt.Errorf("signature mismatch")
	default:
		rawSig := strings.TrimPrefix(task.Signature, "hmac:")
		key := []byte(cfg.NodeTaskSignKey)
		if len(key) == 0 {
			// The "unsigned" fallback key is publicly known; never let it
			// stand in for configured asymmetric verification.
			if strings.TrimSpace(cfg.TaskSigningPubKeys) != "" {
				return fmt.Errorf("hmac signature rejected: no task signing key configured and ed25519 task signing is enabled")
			}
			key = []byte("unsigned")
		}
		h := hmac.New(sha256.New, key)
		_, _ = h.Write(message) // hash.Hash.Write never returns an error
		expected := base64.RawURLEncoding.EncodeToString(h.Sum(nil))
		if !hmac.Equal([]byte(rawSig), []byte(expected)) {
			return fmt.Errorf("signature mismatch")
		}
		return nil
	}
}

// taskSigningPublicKeys parses a list of base64url (unpadded) Ed25519 public
// keys separated by commas, semicolons, spaces, tabs or newlines. Entries that
// do not decode to exactly ed25519.PublicKeySize bytes are silently skipped.
func taskSigningPublicKeys(raw string) []ed25519.PublicKey {
	fields := strings.FieldsFunc(strings.TrimSpace(raw), func(r rune) bool {
		return r == ',' || r == ';' || r == '\n' || r == '\r' || r == '\t' || r == ' '
	})
	out := make([]ed25519.PublicKey, 0, len(fields))
	for _, field := range fields {
		field = strings.TrimSpace(field)
		if field == "" {
			continue
		}
		pubBytes, err := base64.RawURLEncoding.DecodeString(field)
		if err != nil || len(pubBytes) != ed25519.PublicKeySize {
			continue
		}
		out = append(out, ed25519.PublicKey(pubBytes))
	}
	return out
}

// nodeTaskSignatureMessage returns the bytes that task signatures cover:
// "<taskType>:<nodeID>:<taskID>:<params JSON>".
func nodeTaskSignatureMessage(taskType, nodeID, taskID string, params []byte) []byte {
	return []byte(taskType + ":" + nodeID + ":" + taskID + ":" + string(params))
}

// generateNodeCSR creates a fresh Ed25519 key pair and a PEM-encoded
// certificate signing request for this node.
//
// The request's subject is CN=node-<nodeID>, O=gpuaas, OU=node-agent, and it
// carries a single DNS SAN node-<nodeID>.internal.gpuaas.io. It returns the
// private key and the PEM text ("CERTIFICATE REQUEST" block).
func generateNodeCSR(nodeID string) (ed25519.PrivateKey, string, error) {
	commonName := "node-" + nodeID

	_, key, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr != nil {
		return nil, "", genErr
	}

	var request x509.CertificateRequest
	request.Subject = pkix.Name{
		CommonName:         commonName,
		Organization:       []string{"gpuaas"},
		OrganizationalUnit: []string{"node-agent"},
	}
	request.DNSNames = []string{commonName + ".internal.gpuaas.io"}

	der, csrErr := x509.CreateCertificateRequest(rand.Reader, &request, key)
	if csrErr != nil {
		return nil, "", csrErr
	}
	encoded := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der})
	return key, string(encoded), nil
}

// privateKeyToPEM encodes key as PKCS#8 DER wrapped in a "PRIVATE KEY" PEM block.
func privateKeyToPEM(key ed25519.PrivateKey) (string, error) {
	pkcs8, marshalErr := x509.MarshalPKCS8PrivateKey(key)
	if marshalErr != nil {
		return "", marshalErr
	}
	block := &pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8}
	return string(pem.EncodeToMemory(block)), nil
}

// writeSecureFile atomically replaces path with content and sets its mode
// to exactly perm.
//
// The parent directory is created with mode 0755 if needed. The content is
// written to a temporary file in the same directory, chmod'ed to perm before
// any data is written, fsync'ed, and renamed over path. This has two effects:
//   - perm is enforced even when path already exists with looser permissions.
//     os.WriteFile only applies perm when it creates the file, so a
//     pre-existing 0644 key file would otherwise stay world-readable.
//   - readers and crashes never observe a truncated or half-written key or
//     certificate.
//
// Because the final step is a rename, a symlink at path is replaced by a
// regular file rather than written through.
func writeSecureFile(path, content string, perm os.FileMode) error {
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return err
	}
	tmp, err := os.CreateTemp(dir, "."+filepath.Base(path)+".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	committed := false
	defer func() {
		// On any failure before the rename, drop the partial temp file.
		if !committed {
			_ = tmp.Close()
			_ = os.Remove(tmpName)
		}
	}()

	if err := tmp.Chmod(perm); err != nil {
		return err
	}
	if _, err := tmp.WriteString(content); err != nil {
		return err
	}
	if err := tmp.Sync(); err != nil {
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpName, path); err != nil {
		return err
	}
	committed = true
	return nil
}

// loadCertChain reads path and parses every "CERTIFICATE" PEM block in it,
// in file order.
//
// Other PEM block types are skipped, and decoding stops at the first
// non-PEM data. A file with no certificates yields an empty (non-nil) slice
// and no error. Read errors are returned as-is; a malformed certificate block
// fails the whole load with a "parse certificate" error.
func loadCertChain(path string) ([]*x509.Certificate, error) {
	pemData, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, readErr
	}
	chain := make([]*x509.Certificate, 0, 2)
	for remaining := pemData; len(remaining) > 0; {
		var block *pem.Block
		block, remaining = pem.Decode(remaining)
		if block == nil {
			break
		}
		if block.Type != "CERTIFICATE" {
			continue
		}
		parsed, parseErr := x509.ParseCertificate(block.Bytes)
		if parseErr != nil {
			return nil, fmt.Errorf("parse certificate: %w", parseErr)
		}
		chain = append(chain, parsed)
	}
	return chain, nil
}

// selectLeafCert returns the first non-nil certificate that is not a CA.
// If there is none, it falls back to certs[0] (which may itself be nil).
// It returns nil for an empty slice.
func selectLeafCert(certs []*x509.Certificate) *x509.Certificate {
	for i := range certs {
		if c := certs[i]; c != nil && !c.IsCA {
			return c
		}
	}
	if len(certs) > 0 {
		return certs[0]
	}
	return nil
}

// certSummary holds the details of the selected leaf certificate that the
// agent logs and compares across enrollment and renewal.
type certSummary struct {
	// CommonName is the leaf subject CN with surrounding whitespace removed.
	CommonName string
	// NotAfter is the leaf certificate's expiry time.
	NotAfter time.Time
	// Fingerprint is the lowercase hex SHA-256 of the leaf's DER encoding.
	Fingerprint string
}

// currentCertSummary loads the certificate chain at path and summarizes its
// leaf (as chosen by selectLeafCert): trimmed CommonName, NotAfter, and the
// lowercase hex SHA-256 fingerprint of its DER bytes. It errors when the
// file cannot be loaded or contains no certificates.
func currentCertSummary(path string) (certSummary, error) {
	chain, err := loadCertChain(path)
	switch {
	case err != nil:
		return certSummary{}, err
	case len(chain) == 0:
		return certSummary{}, fmt.Errorf("no certificate found")
	}

	leaf := selectLeafCert(chain)
	digest := sha256.Sum256(leaf.Raw)

	var summary certSummary
	summary.CommonName = strings.TrimSpace(leaf.Subject.CommonName)
	summary.NotAfter = leaf.NotAfter
	summary.Fingerprint = fmt.Sprintf("%x", digest[:])
	return summary, nil
}

// localCertMatchesNodeID reports whether any certificate in the chain at
// certPath names nodeID (see localCertNodeState). It returns false together
// with the error when the chain cannot be loaded.
func localCertMatchesNodeID(certPath, nodeID string) (bool, error) {
	state, err := localCertNodeState(certPath, nodeID, time.Now().UTC())
	return err == nil && state.matches, err
}

// localCertChainsToBundle reports whether the leaf certificate at certPath
// verifies, at time now and for client-auth usage, against roots taken
// solely from the PEM bundle at caBundlePath.
//
// Any other certificates in the cert file are offered as intermediates. A
// failed verification is an answer, not an error: it returns (false, nil).
// Errors are returned only when the cert file or bundle cannot be read or
// contains nothing usable.
func localCertChainsToBundle(certPath, caBundlePath string, now time.Time) (bool, error) {
	chain, err := loadCertChain(certPath)
	if err != nil {
		return false, err
	}
	if len(chain) == 0 {
		return false, fmt.Errorf("no certificate found")
	}
	leaf := selectLeafCert(chain)
	if leaf == nil {
		return false, fmt.Errorf("no leaf certificate found")
	}

	caPEM, err := os.ReadFile(caBundlePath)
	if err != nil {
		return false, err
	}

	opts := x509.VerifyOptions{
		Roots:         x509.NewCertPool(),
		Intermediates: x509.NewCertPool(),
		CurrentTime:   now,
		KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}
	if !opts.Roots.AppendCertsFromPEM(caPEM) {
		return false, fmt.Errorf("append node cert ca bundle: no certificates parsed")
	}
	for _, c := range chain {
		if c != nil && c != leaf {
			opts.Intermediates.AddCert(c)
		}
	}

	_, verifyErr := leaf.Verify(opts)
	return verifyErr == nil, nil
}

// certNodeState describes the local certificate chain relative to a node ID,
// as computed by localCertNodeState at a given evaluation time.
type certNodeState struct {
	matches  bool      // some cert in the chain has CN or DNS SAN naming the node
	expired  bool      // evaluation time is after the selected leaf's NotAfter
	notAfter time.Time // the selected leaf's NotAfter
}

// localCertNodeState loads the chain at certPath and reports:
//   - whether any certificate in it identifies nodeID, either by subject CN
//     "node-<nodeID>" or by DNS SAN "node-<nodeID>.internal.gpuaas.io"
//     (both compared after trimming whitespace);
//   - whether the selected leaf has expired as of now, and its NotAfter.
//
// It errors when the file cannot be loaded or contains no certificates.
func localCertNodeState(certPath, nodeID string, now time.Time) (certNodeState, error) {
	chain, err := loadCertChain(certPath)
	if err != nil {
		return certNodeState{}, err
	}
	if len(chain) == 0 {
		return certNodeState{}, fmt.Errorf("no certificate found")
	}

	wantCN := "node-" + nodeID
	wantDNS := wantCN + ".internal.gpuaas.io"
	leaf := selectLeafCert(chain)

	state := certNodeState{
		expired:  now.After(leaf.NotAfter),
		notAfter: leaf.NotAfter,
	}
	for _, c := range chain {
		named := strings.TrimSpace(c.Subject.CommonName) == wantCN
		for i := 0; !named && i < len(c.DNSNames); i++ {
			named = strings.TrimSpace(c.DNSNames[i]) == wantDNS
		}
		if named {
			state.matches = true
			break
		}
	}
	return state, nil
}
