package audit

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"
)

const (
	// AuditCanonicalizationV1 names the canonical encoding used for record
	// digests and batch root hashes (the JSON encoding of canonicalAuditRecord
	// and canonicalBatchRoot); changing that encoding changes every digest.
	// AuditDigestSHA256 names the digest algorithm: hex-encoded SHA-256.
	AuditCanonicalizationV1 = "audit-record-canonical/v1"
	AuditDigestSHA256       = "sha256"

	// Verification status values used by BatchManifest.VerificationStatus and
	// IntegrityStatus. In the code shown here, BuildIntegrityStatus reports
	// VerificationBlocked when there are no manifests, and BuildBatchManifest
	// uses VerificationPartial only as an interim value before the root hash
	// has been computed.
	VerificationPass    = "pass"
	VerificationFail    = "fail"
	VerificationPartial = "partial"
	VerificationBlocked = "blocked"
)

// ErrAuditIntegrityInvalidInput is returned unwrapped when required inputs are
// missing or blank (for example a blank ID or environment profile, a zero
// timestamp, a non-positive sequence number, or no events/record digests).
// ErrAuditIntegrityMismatch is wrapped with %w into errors that describe a
// failed consistency or hash-chain check; match it with errors.Is.
var (
	ErrAuditIntegrityInvalidInput = errors.New("audit integrity input invalid")
	ErrAuditIntegrityMismatch     = errors.New("audit integrity mismatch")
)

// RecordDigest is the digest of one audit event, as produced by DigestEvent.
//
// AuditID is the event ID with surrounding whitespace trimmed, DigestHex is the
// lowercase hex SHA-256 of the event's canonical JSON, DigestAlgorithm and
// CanonicalizationVersion identify how it was computed, and OccurredAt is the
// event time normalized to UTC.
//
// RecordDigest values are JSON-encoded verbatim into the batch root hash
// (canonicalBatchRoot.RecordDigests), so the field order and JSON tags below
// are part of the hash format: changing them changes every batch root.
type RecordDigest struct {
	AuditID                 string    `json:"audit_id"`
	DigestAlgorithm         string    `json:"digest_algorithm"`
	DigestHex               string    `json:"digest_hex"`
	CanonicalizationVersion string    `json:"canonicalization_version"`
	OccurredAt              time.Time `json:"occurred_at"`
}

// BatchManifest describes a sealed batch of audit records. BuildBatchManifest
// produces it, and VerifyBatchManifest / VerifyBatchChain check it.
//
// BatchRootHash is the hex SHA-256 of the canonical JSON of the header fields
// plus RecordDigests (see batchRootHash and canonicalBatchRoot). The root does
// NOT cover BatchRootHash itself, VerificationStatus, VerifierMessage, or
// CreatedAt, so those fields are not integrity-protected. A nil
// PreviousBatchHash and an empty or blank one hash identically.
//
// Batches link into a chain through PreviousBatchHash, which is expected to
// equal the preceding batch's BatchRootHash, with SequenceNumber increasing by
// exactly one. SequenceNumber must be positive.
type BatchManifest struct {
	BatchID                 string         `json:"batch_id"`
	EnvironmentProfile      string         `json:"environment_profile"`
	SequenceNumber          int64          `json:"sequence_number"`
	FirstAuditID            string         `json:"first_audit_id"`
	LastAuditID             string         `json:"last_audit_id"`
	FirstOccurredAt         time.Time      `json:"first_occurred_at"`
	LastOccurredAt          time.Time      `json:"last_occurred_at"`
	RecordCount             int            `json:"record_count"`
	CanonicalizationVersion string         `json:"canonicalization_version"`
	RecordDigestAlgorithm   string         `json:"record_digest_algorithm"`
	RecordDigests           []RecordDigest `json:"record_digests"`
	PreviousBatchHash       *string        `json:"previous_batch_hash,omitempty"`
	BatchRootHash           string         `json:"batch_root_hash"`
	VerificationStatus      string         `json:"verification_status"`
	VerifierMessage         string         `json:"verifier_message,omitempty"`
	CreatedAt               time.Time      `json:"created_at"`
}

// BatchInput is the input to BuildBatchManifest.
//
// BatchID and EnvironmentProfile must be non-blank, SequenceNumber must be
// positive, and Events must be non-empty. Events are digested in slice order;
// the first and last entries define the manifest's audit-ID and occurred-at
// range. A zero CreatedAt is replaced with the current time. PreviousBatchHash
// is trimmed, and a blank value is treated as absent.
type BatchInput struct {
	BatchID            string
	EnvironmentProfile string
	SequenceNumber     int64
	PreviousBatchHash  *string
	Events             []Event
	CreatedAt          time.Time
}

// IntegrityStatus is the summary BuildIntegrityStatus reports for one
// environment.
//
// FreshnessStatus, ContinuityStatus, and VerificationStatus each hold one of
// the Verification* constants:
//   - all three are VerificationBlocked when no manifests exist;
//   - otherwise each is VerificationPass or VerificationFail, and
//     VerificationStatus fails whenever freshness or continuity fails.
//
// Message is a human-readable explanation. The Latest* fields describe the
// last manifest passed to BuildIntegrityStatus.
type IntegrityStatus struct {
	EnvironmentProfile string     `json:"environment_profile"`
	LatestBatchID      string     `json:"latest_batch_id,omitempty"`
	LatestSequence     int64      `json:"latest_sequence,omitempty"`
	LatestCreatedAt    *time.Time `json:"latest_created_at,omitempty"`
	FreshnessStatus    string     `json:"freshness_status"`
	ContinuityStatus   string     `json:"continuity_status"`
	VerificationStatus string     `json:"verification_status"`
	Message            string     `json:"message,omitempty"`
}

// DigestEvent computes the canonical SHA-256 digest of a single audit event.
//
// It builds a canonicalAuditRecord from the event and hashes its JSON
// encoding:
//   - text fields are whitespace-trimmed;
//   - optional pointer fields are flattened with stringPtrValue;
//   - metadata is normalized by canonicalMetadata;
//   - OccurredAt is normalized to UTC.
//
// It returns ErrAuditIntegrityInvalidInput when any of these hold:
//   - the ID is blank;
//   - OccurredAt is zero;
//   - ActorRole, Action, TargetType, Result, or CorrelationID is blank after
//     trimming.
//
// The optional fields (OrgID, the actor IDs, TargetID) may be empty.
//
// NOTE(review): metadata that cannot be JSON-encoded does not cause an error
// here. canonicalMetadata substitutes an error-marker map, so that metadata's
// content is not covered by the digest.
func DigestEvent(event Event) (RecordDigest, error) {
	id := strings.TrimSpace(event.ID)
	if id == "" || event.OccurredAt.IsZero() {
		return RecordDigest{}, ErrAuditIntegrityInvalidInput
	}

	record := canonicalAuditRecord{
		ID:                    id,
		OrgID:                 stringPtrValue(event.OrgID),
		ActorUserID:           stringPtrValue(event.ActorUserID),
		ActorServiceAccountID: stringPtrValue(event.ActorServiceAccountID),
		ActorRole:             strings.TrimSpace(event.ActorRole),
		Action:                strings.TrimSpace(event.Action),
		TargetType:            strings.TrimSpace(event.TargetType),
		TargetID:              stringPtrValue(event.TargetID),
		Result:                strings.TrimSpace(event.Result),
		CorrelationID:         strings.TrimSpace(event.CorrelationID),
		Metadata:              canonicalMetadata(event.Metadata),
		OccurredAt:            normalizeAuditIntegrityTime(event.OccurredAt),
	}

	// Every one of these fields is mandatory for a digestible record.
	required := []string{record.ActorRole, record.Action, record.TargetType, record.Result, record.CorrelationID}
	for _, field := range required {
		if field == "" {
			return RecordDigest{}, ErrAuditIntegrityInvalidInput
		}
	}

	hexDigest, err := hashCanonicalJSON(record)
	if err != nil {
		return RecordDigest{}, err
	}

	var out RecordDigest
	out.AuditID = record.ID
	out.DigestAlgorithm = AuditDigestSHA256
	out.DigestHex = hexDigest
	out.CanonicalizationVersion = AuditCanonicalizationV1
	out.OccurredAt = record.OccurredAt
	return out, nil
}

// BuildBatchManifest digests every event in input, in slice order, and
// assembles a sealed BatchManifest. Its BatchRootHash covers the header
// fields and all record digests.
//
// The first and last events supply the audit-ID and occurred-at range. Events
// are not sorted or checked for ordering here. A zero CreatedAt is replaced
// with the current time (normalized to UTC). PreviousBatchHash is trimmed, and
// a blank value becomes nil.
//
// It returns ErrAuditIntegrityInvalidInput when any of these hold:
//   - the batch ID or environment profile is blank;
//   - SequenceNumber is not positive;
//   - there are no events.
//
// Otherwise it propagates the first DigestEvent or hashing error. On success
// the manifest's VerificationStatus is VerificationPass and VerifierMessage is
// "batch root verified". The root is freshly computed here rather than
// independently re-verified.
func BuildBatchManifest(input BatchInput) (BatchManifest, error) {
	batchID := strings.TrimSpace(input.BatchID)
	profile := strings.TrimSpace(input.EnvironmentProfile)
	if batchID == "" || profile == "" || input.SequenceNumber <= 0 || len(input.Events) == 0 {
		return BatchManifest{}, ErrAuditIntegrityInvalidInput
	}

	created := input.CreatedAt
	if created.IsZero() {
		created = time.Now()
	}
	created = normalizeAuditIntegrityTime(created)

	digests := make([]RecordDigest, len(input.Events))
	for i := range input.Events {
		d, err := DigestEvent(input.Events[i])
		if err != nil {
			return BatchManifest{}, err
		}
		digests[i] = d
	}
	first, last := digests[0], digests[len(digests)-1]

	manifest := BatchManifest{
		BatchID:                 batchID,
		EnvironmentProfile:      profile,
		SequenceNumber:          input.SequenceNumber,
		FirstAuditID:            first.AuditID,
		LastAuditID:             last.AuditID,
		FirstOccurredAt:         first.OccurredAt,
		LastOccurredAt:          last.OccurredAt,
		RecordCount:             len(digests),
		CanonicalizationVersion: AuditCanonicalizationV1,
		RecordDigestAlgorithm:   AuditDigestSHA256,
		RecordDigests:           digests,
		PreviousBatchHash:       trimHashPtr(input.PreviousBatchHash),
		CreatedAt:               created,
	}

	root, err := batchRootHash(manifest)
	if err != nil {
		return BatchManifest{}, err
	}
	// The status fields are outside the root hash, so filling them in after
	// computing it does not affect the root.
	manifest.BatchRootHash = root
	manifest.VerificationStatus = VerificationPass
	manifest.VerifierMessage = "batch root verified"
	return manifest, nil
}

// VerifyBatchManifest checks that a single manifest is internally consistent.
// It verifies that:
//   - the required header fields are present;
//   - RecordCount matches the number of record digests;
//   - the first/last audit IDs match the first/last record digests;
//   - the first/last occurred-at timestamps match the first/last record
//     digests;
//   - BatchRootHash, ignoring surrounding whitespace, equals the recomputed
//     root.
//
// It returns ErrAuditIntegrityInvalidInput (unwrapped) for missing required
// fields, and an error wrapping ErrAuditIntegrityMismatch for any
// inconsistency.
//
// It does not recompute individual record digests from events. It also does
// not check the link to a previous batch; VerifyBatchChain does that.
func VerifyBatchManifest(manifest BatchManifest) error {
	if strings.TrimSpace(manifest.BatchID) == "" || strings.TrimSpace(manifest.EnvironmentProfile) == "" || manifest.SequenceNumber <= 0 || len(manifest.RecordDigests) == 0 {
		return ErrAuditIntegrityInvalidInput
	}
	if manifest.RecordCount != len(manifest.RecordDigests) {
		return fmt.Errorf("%w: record_count=%d digest_count=%d", ErrAuditIntegrityMismatch, manifest.RecordCount, len(manifest.RecordDigests))
	}
	first := manifest.RecordDigests[0]
	last := manifest.RecordDigests[len(manifest.RecordDigests)-1]
	if manifest.FirstAuditID != first.AuditID || manifest.LastAuditID != last.AuditID {
		return fmt.Errorf("%w: audit id range does not match record digests", ErrAuditIntegrityMismatch)
	}
	// The header's time range must agree with the digests just like the ID
	// range. Otherwise a root computed over an inconsistent range would pass.
	// Equal compares instants, so the same moment in a different location
	// is accepted.
	if !manifest.FirstOccurredAt.Equal(first.OccurredAt) || !manifest.LastOccurredAt.Equal(last.OccurredAt) {
		return fmt.Errorf("%w: occurred_at range does not match record digests", ErrAuditIntegrityMismatch)
	}
	root, err := batchRootHash(manifest)
	if err != nil {
		return err
	}
	if root != strings.TrimSpace(manifest.BatchRootHash) {
		return fmt.Errorf("%w: root hash mismatch", ErrAuditIntegrityMismatch)
	}
	return nil
}

// VerifyBatchChain verifies each manifest with VerifyBatchManifest. It then
// checks that consecutive manifests form an unbroken hash chain:
//   - sequence numbers increase by exactly one;
//   - each manifest's PreviousBatchHash equals the preceding manifest's
//     BatchRootHash.
//
// Hashes are compared after trimming surrounding whitespace, which matches how
// VerifyBatchManifest and batchRootHash treat them. The first manifest's
// PreviousBatchHash is not checked, so a chain may begin mid-history. An empty
// slice verifies successfully.
//
// It returns the first failure: either the VerifyBatchManifest error as-is, or
// an error wrapping ErrAuditIntegrityMismatch.
func VerifyBatchChain(manifests []BatchManifest) error {
	for i := range manifests {
		// Index into the slice rather than copying each (large) manifest.
		current := &manifests[i]
		if err := VerifyBatchManifest(*current); err != nil {
			return err
		}
		if i == 0 {
			continue
		}
		previous := &manifests[i-1]
		if current.SequenceNumber != previous.SequenceNumber+1 {
			return fmt.Errorf("%w: sequence gap previous=%d current=%d", ErrAuditIntegrityMismatch, previous.SequenceNumber, current.SequenceNumber)
		}
		if current.PreviousBatchHash == nil || strings.TrimSpace(*current.PreviousBatchHash) != strings.TrimSpace(previous.BatchRootHash) {
			return fmt.Errorf("%w: previous hash mismatch sequence=%d", ErrAuditIntegrityMismatch, current.SequenceNumber)
		}
	}
	return nil
}

// BuildIntegrityStatus summarizes the integrity of an environment's audit batch
// chain.
//
// manifests is expected in chain order. The last element is reported as the
// latest batch, and the whole slice is checked with VerifyBatchChain. With no
// manifests, every status is VerificationBlocked.
//
// Freshness fails when freshnessThreshold is positive and the latest batch's
// CreatedAt is older than now minus the threshold. A zero now is replaced with
// the current time. When both freshness and continuity fail, Message reports
// both findings. Otherwise it reports whichever check failed, or a success
// message.
//
// NOTE(review): a zero CreatedAt on the latest manifest skips the freshness
// check entirely, leaving FreshnessStatus at "pass". CreatedAt is also not
// covered by the batch root hash.
func BuildIntegrityStatus(environmentProfile string, manifests []BatchManifest, now time.Time, freshnessThreshold time.Duration) IntegrityStatus {
	status := IntegrityStatus{
		EnvironmentProfile: strings.TrimSpace(environmentProfile),
		FreshnessStatus:    VerificationBlocked,
		ContinuityStatus:   VerificationBlocked,
		VerificationStatus: VerificationBlocked,
		Message:            "no audit batch manifests available",
	}
	if len(manifests) == 0 {
		return status
	}
	latest := manifests[len(manifests)-1]
	latestCreatedAt := normalizeAuditIntegrityTime(latest.CreatedAt)
	status.LatestBatchID = latest.BatchID
	status.LatestSequence = latest.SequenceNumber
	status.LatestCreatedAt = &latestCreatedAt
	status.FreshnessStatus = VerificationPass
	status.ContinuityStatus = VerificationPass
	status.VerificationStatus = VerificationPass
	status.Message = "audit batch hash chain verified"
	if freshnessThreshold > 0 && !latestCreatedAt.IsZero() {
		now = normalizeAuditIntegrityTime(now)
		if now.IsZero() {
			now = normalizeAuditIntegrityTime(time.Now())
		}
		if now.Sub(latestCreatedAt) > freshnessThreshold {
			status.FreshnessStatus = VerificationFail
			status.VerificationStatus = VerificationFail
			status.Message = "latest audit batch is stale"
		}
	}
	if err := VerifyBatchChain(manifests); err != nil {
		status.ContinuityStatus = VerificationFail
		status.VerificationStatus = VerificationFail
		if status.FreshnessStatus == VerificationFail {
			// Keep the staleness finding instead of overwriting it with the
			// continuity error.
			status.Message = "latest audit batch is stale; " + err.Error()
		} else {
			status.Message = err.Error()
		}
	}
	return status
}

// canonicalAuditRecord is the exact value DigestEvent JSON-encodes and hashes
// for one event (canonicalization AuditCanonicalizationV1). Its field order and
// JSON tags define the record-digest format. Changing either changes every
// record digest.
//
// When DigestEvent fills it in:
//   - text fields are trimmed;
//   - optional pointer fields are flattened with stringPtrValue (presumably
//     "" for nil; confirm against its definition);
//   - Metadata comes from canonicalMetadata, and encoding/json emits its map
//     keys in sorted order;
//   - OccurredAt is normalized to UTC.
type canonicalAuditRecord struct {
	ID                    string         `json:"id"`
	OrgID                 string         `json:"org_id"`
	ActorUserID           string         `json:"actor_user_id"`
	ActorServiceAccountID string         `json:"actor_service_account_id"`
	ActorRole             string         `json:"actor_role"`
	Action                string         `json:"action"`
	TargetType            string         `json:"target_type"`
	TargetID              string         `json:"target_id"`
	Result                string         `json:"result"`
	CorrelationID         string         `json:"correlation_id"`
	Metadata              map[string]any `json:"metadata"`
	OccurredAt            time.Time      `json:"occurred_at"`
}

// canonicalBatchRoot is the exact value batchRootHash JSON-encodes and hashes
// to produce a manifest's BatchRootHash. Its field order and JSON tags define
// the root-hash format. Changing either changes every batch root.
//
// It omits BatchManifest's BatchRootHash, VerificationStatus, VerifierMessage,
// and CreatedAt, so those fields have no effect on the root. PreviousBatchHash
// is a plain string here, and an absent previous hash is encoded as "".
type canonicalBatchRoot struct {
	BatchID                 string         `json:"batch_id"`
	EnvironmentProfile      string         `json:"environment_profile"`
	SequenceNumber          int64          `json:"sequence_number"`
	FirstAuditID            string         `json:"first_audit_id"`
	LastAuditID             string         `json:"last_audit_id"`
	FirstOccurredAt         time.Time      `json:"first_occurred_at"`
	LastOccurredAt          time.Time      `json:"last_occurred_at"`
	RecordCount             int            `json:"record_count"`
	CanonicalizationVersion string         `json:"canonicalization_version"`
	RecordDigestAlgorithm   string         `json:"record_digest_algorithm"`
	RecordDigests           []RecordDigest `json:"record_digests"`
	PreviousBatchHash       string         `json:"previous_batch_hash"`
}

// batchRootHash recomputes a manifest's root hash: the hex SHA-256 of the JSON
// encoding of a canonicalBatchRoot built from the manifest.
//
// When building that value:
//   - header strings are whitespace-trimmed;
//   - the occurred-at bounds are normalized to UTC;
//   - a nil PreviousBatchHash is encoded as "";
//   - RecordDigests are hashed exactly as stored, with no normalization.
func batchRootHash(manifest BatchManifest) (string, error) {
	var root canonicalBatchRoot
	root.BatchID = strings.TrimSpace(manifest.BatchID)
	root.EnvironmentProfile = strings.TrimSpace(manifest.EnvironmentProfile)
	root.SequenceNumber = manifest.SequenceNumber
	root.FirstAuditID = strings.TrimSpace(manifest.FirstAuditID)
	root.LastAuditID = strings.TrimSpace(manifest.LastAuditID)
	root.FirstOccurredAt = normalizeAuditIntegrityTime(manifest.FirstOccurredAt)
	root.LastOccurredAt = normalizeAuditIntegrityTime(manifest.LastOccurredAt)
	root.RecordCount = manifest.RecordCount
	root.CanonicalizationVersion = strings.TrimSpace(manifest.CanonicalizationVersion)
	root.RecordDigestAlgorithm = strings.TrimSpace(manifest.RecordDigestAlgorithm)
	root.RecordDigests = manifest.RecordDigests
	if prev := manifest.PreviousBatchHash; prev != nil {
		root.PreviousBatchHash = strings.TrimSpace(*prev)
	}
	return hashCanonicalJSON(root)
}

// hashCanonicalJSON encodes value with encoding/json and returns the lowercase
// hex SHA-256 of the encoded bytes. Determinism relies on encoding/json: struct
// fields are emitted in declaration order and map keys in sorted order.
// Encoding errors are returned unchanged, together with an empty hash.
func hashCanonicalJSON(value any) (string, error) {
	encoded, err := json.Marshal(value)
	if err != nil {
		return "", err
	}
	hasher := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	hasher.Write(encoded)
	return hex.EncodeToString(hasher.Sum(nil)), nil
}

// canonicalMetadata returns a normalized copy of metadata produced by a JSON
// round trip. Nested values become plain JSON types: maps, slices, strings,
// bools, nil, and float64 numbers. A nil or empty input yields an empty,
// non-nil map.
//
// The function never fails. If the round trip errors, it returns a single-key
// map {"_metadata_canonicalization_error": <error text>} instead, and that map
// is what ends up in the digest.
func canonicalMetadata(metadata map[string]any) map[string]any {
	if metadata == nil {
		return map[string]any{}
	}
	var normalized map[string]any
	encoded, err := json.Marshal(metadata)
	if err == nil {
		err = json.Unmarshal(encoded, &normalized)
	}
	switch {
	case err != nil:
		return map[string]any{"_metadata_canonicalization_error": err.Error()}
	case normalized == nil:
		return map[string]any{}
	default:
		return normalized
	}
}

// normalizeAuditIntegrityTime converts a timestamp into the form used for
// hashing: UTC location, with any monotonic clock reading stripped. Any zero
// instant, whatever its location, maps to the plain time.Time{} value.
func normalizeAuditIntegrityTime(value time.Time) time.Time {
	if !value.IsZero() {
		// Round(0) is the documented way to drop the monotonic reading.
		return value.UTC().Round(0)
	}
	return time.Time{}
}

// trimHashPtr returns a pointer to a whitespace-trimmed copy of *value. It
// returns nil when value is nil or trims to the empty string. The input string
// is never modified or aliased.
func trimHashPtr(value *string) *string {
	if value != nil {
		if trimmed := strings.TrimSpace(*value); trimmed != "" {
			return &trimmed
		}
	}
	return nil
}
