import { useEffect, useRef, useState } from 'react';
import {
  useSnapNavigation,
  type UseSnapNavigationResults,
} from './snap-navigation';
import { useSnapTargets } from './snap-utils';

/**
 * Fraction (0-1) of a frame that must lie inside the reel for the frame to
 * count as visible. Used as the IntersectionObserver threshold for frames and
 * when estimating how many frames fit in the reel at once.
 */
const FRAME_INTERSECTION_THRESHOLD = 0.9;

/**
 * Values returned by `useReel`: everything `useSnapNavigation` returns, plus
 * the reel's visibility state and props for the reel container.
 */
export interface UseReelResults extends UseSnapNavigationResults {
  /**
   * The index of the first visible frame.
   *
   * Starts at 0 (so server rendering has a stable value) and is -1 when the
   * reel has no frames.
   */
  activeIndex: number;

  /**
   * The number of indicators to show: one per scroll position at which a
   * different frame is first (total frames minus the number of frames that
   * fit at once, plus one). It is 0 when that would be fewer than two.
   */
  indicatorCount: number;

  /**
   * The number of currently visible frames, i.e. frames reported as
   * intersecting the reel by its IntersectionObserver.
   */
  visibleCount: number;

  /**
   * The total number of frames in the reel.
   */
  totalCount: number;

  /**
   * Props for the reel container element.
   */
  reelProps: {
    /**
     * The tab index of the reel, if the reel has no focusable children.
     * Setting it to `0` lets the reel itself receive keyboard focus; it is
     * `undefined` when some descendant is already focusable.
     */
    tabIndex?: 0;
  };
}

/**
 * Options accepted by `useReel`.
 */
export interface UseReelOptions {
  /**
   * The ref to the container element. Its snap targets (as found by
   * `useSnapTargets`) are treated as the reel's frames.
   */
  ref:
    | React.RefObject<HTMLElement | null>
    | React.MutableRefObject<HTMLElement | undefined>;

  /**
   * Triggered when the active index (the first visible frame) changes.
   * It is invoked from an effect, and only when the computed index differs
   * from the last one reported.
   */
  onChange?: ({
    activeIndex,
    currentTarget,
    target,
  }: {
    /**
     * The index of the first visible frame.
     */
    activeIndex: number;

    /**
     * The first visible frame, or `null` when no frame exists at
     * `activeIndex`.
     */
    target: HTMLElement | null;

    /**
     * The container element (the reel itself).
     */
    currentTarget: HTMLElement;
  }) => void;
}

/**
 * Tracks which frames of a snapping reel are visible, and derives from them
 * the active (first visible) frame and the number of page indicators.
 *
 * @param options.ref - The reel container; its snap targets are the frames.
 * @param options.onChange - Called when the reported active index changes.
 * @returns Snap navigation helpers plus reel state and container props.
 */
export function useReel({ ref, onChange }: UseReelOptions): UseReelResults {
  // Default to 0 for SSR
  const [activeIndex, setActiveIndex] = useState(0);
  // Last index reported, so `onChange` only fires on real changes.
  const prevActiveIndex = useRef(activeIndex);
  // Lazy initializer: avoids allocating a throwaway Set on every render.
  const [visibleFrames, setVisibleFrames] = useState<Set<HTMLElement>>(
    () => new Set()
  );
  const { frames, maxVisibleFrames, hasFocusableChildren } = useReelFrames(ref);
  const hasFrames = frames && frames.length > 0;
  // A window of `maxVisibleFrames` can start at this many distinct frames.
  let indicatorCount =
    hasFrames && maxVisibleFrames > 0
      ? frames.length - maxVisibleFrames + 1
      : 0;
  // A single position needs no indicators.
  if (indicatorCount < 2) {
    indicatorCount = 0;
  }

  const snapNavigation = useSnapNavigation({ ref });

  useEffect(() => {
    const container = ref.current;
    if (!container || !frames) {
      return;
    }

    const nextActiveIndex = getNextActiveIndex(
      frames,
      visibleFrames,
      indicatorCount
    );
    // We intentionally only update the active index if there are visible frames.
    // While you're scrolling there might be 0 fully visible frames but we want the
    // indicator to "lag" at the previous position until the scroll is complete.
    if (nextActiveIndex === undefined) {
      return;
    }

    setActiveIndex(nextActiveIndex);
    if (nextActiveIndex !== prevActiveIndex.current) {
      prevActiveIndex.current = nextActiveIndex;
      if (onChange) {
        onChange({
          activeIndex: nextActiveIndex,
          target: frames[nextActiveIndex] || null,
          currentTarget: container,
        });
      }
    }
  }, [frames, visibleFrames, indicatorCount, onChange, ref]);

  useEffect(() => {
    setVisibleFrames(new Set());
    const container = ref.current;
    if (!frames || frames.length === 0 || !container) return;
    const rootMargin = window.getComputedStyle(container).paddingInline;
    const intersectionObserver = new IntersectionObserver(
      (entries) => {
        setVisibleFrames((prev) => {
          // Copy rather than mutate `prev`: React state must stay immutable.
          const next = new Set(prev);
          // Apply entries in delivery order so the latest entry for a frame wins.
          for (const entry of entries) {
            const frame = entry.target as HTMLElement;
            // `isIntersecting` is true for any overlap, including when a frame
            // drops *below* the threshold, so the ratio must be checked too.
            const isVisible =
              entry.isIntersecting &&
              entry.intersectionRatio >= FRAME_INTERSECTION_THRESHOLD;
            if (isVisible) {
              next.add(frame);
            } else {
              next.delete(frame);
            }
          }
          return next;
        });
      },
      { root: container, threshold: FRAME_INTERSECTION_THRESHOLD, rootMargin }
    );
    for (const frame of frames) {
      intersectionObserver.observe(frame);
    }
    return () => {
      intersectionObserver.disconnect();
    };
  }, [frames, ref]);

  return {
    ...snapNavigation,
    reelProps: {
      tabIndex: hasFocusableChildren ? undefined : 0,
    },
    activeIndex,
    visibleCount: visibleFrames ? visibleFrames.size : 0,
    totalCount: frames ? frames.length : 0,
    indicatorCount,
  };
}

/**
 * Derives the active index from the set of currently visible frames.
 *
 * @param frames - All frames of the reel, in order.
 * @param visibleFrames - Frames currently considered visible. This may briefly
 *   contain elements that are no longer in `frames`; those are ignored.
 * @param indicatorCount - Number of indicators. When positive, the result is
 *   clamped to the last indicator.
 * @returns -1 when there are no frames. Returns `undefined` when no current
 *   frame is visible, meaning the previous index should be kept. Otherwise
 *   returns the index of the first visible frame, clamped as described.
 */
function getNextActiveIndex(
  frames: readonly HTMLElement[],
  visibleFrames: ReadonlySet<HTMLElement>,
  indicatorCount: number
): number | undefined {
  if (frames.length === 0) {
    return -1;
  }

  let firstVisible = -1;
  for (const frame of visibleFrames) {
    const index = frames.indexOf(frame);
    if (index !== -1 && (firstVisible === -1 || index < firstVisible)) {
      firstVisible = index;
    }
  }

  if (firstVisible === -1) {
    return undefined;
  }
  return indicatorCount > 0
    ? Math.min(indicatorCount - 1, firstVisible)
    : firstVisible;
}

/**
 * Keeps track of the snap targets in the container.
 */
/**
 * Keeps track of the snap targets in the container.
 *
 * Returns:
 * - the frames (the container's snap targets);
 * - an estimate of how many frames fit in the container at once;
 * - whether the container has any focusable descendant.
 *
 * The fit estimate uses the first frame's width (plus the container's
 * `column-gap`) for every frame. A trailing partial frame counts when more
 * than `FRAME_INTERSECTION_THRESHOLD` of it fits.
 */
function useReelFrames(
  ref:
    | React.RefObject<HTMLElement | null>
    | React.MutableRefObject<HTMLElement | undefined>
) {
  const frames = useSnapTargets(ref);
  const [maxVisibleFrames, setMaxVisibleFrames] = useState(0);
  const [hasFocusableChildren, setHasFocusableChildren] = useState(false);

  useEffect(() => {
    const container = ref.current;
    if (!container || !frames || frames.length === 0) return;
    const firstFrame = frames[0];
    if (!firstFrame) return;
    setHasFocusableChildren(elementHasFocusableChildren(container));

    const measure = () => {
      // parseFloat keeps fractional gaps; "normal" parses to NaN -> 0.
      const parsedGap = parseFloat(
        window.getComputedStyle(container).columnGap
      );
      const gap = Number.isNaN(parsedGap) ? 0 : parsedGap;
      const frameWidth = firstFrame.offsetWidth;
      const itemWidth = frameWidth + gap;
      // A zero-width frame (e.g. not rendered) would divide by zero.
      if (itemWidth <= 0) {
        setMaxVisibleFrames(0);
        return;
      }
      // Adding one gap to the container makes n frames + (n - 1) gaps an exact multiple.
      const containerWidth = container.offsetWidth + gap;
      let fittingFrames = Math.floor(containerWidth / itemWidth);
      if (
        containerWidth % itemWidth >
        frameWidth * FRAME_INTERSECTION_THRESHOLD
      ) {
        fittingFrames++;
      }
      setMaxVisibleFrames(fittingFrames);
    };

    // Re-measure when either the container or the reference frame resizes.
    const resizeObserver = new ResizeObserver(() => measure());
    resizeObserver.observe(container);
    resizeObserver.observe(firstFrame);
    return () => {
      resizeObserver.disconnect();
    };
  }, [ref, frames]);
  return { frames, maxVisibleFrames, hasFocusableChildren };
}

/**
 * Reports whether any descendant element of `element` is focusable according
 * to `isFocusable`. The element itself is not considered.
 */
function elementHasFocusableChildren(element: HTMLElement) {
  // Walk element nodes only. Non-matching nodes are skipped (not rejected),
  // so their subtrees are still searched.
  const walker = document.createTreeWalker(element, NodeFilter.SHOW_ELEMENT, {
    acceptNode: (candidate: Node) =>
      isFocusable(candidate)
        ? NodeFilter.FILTER_ACCEPT
        : NodeFilter.FILTER_SKIP,
  });
  // nextNode() yields the first accepted descendant in document order, or null.
  const firstFocusable = walker.nextNode();
  return firstFocusable !== null;
}

/**
 * The `tabIndex` property does not directly access the `tabindex` attribute
 * but gives you the computed "tab index" which works even for elements which
 * are focusable by default.
 */
/**
 * Whether `node` is an element that can receive keyboard focus.
 *
 * The `tabIndex` property does not directly access the `tabindex` attribute
 * but gives you the computed "tab index", which works even for elements that
 * are focusable by default. Disabled form controls still report a default
 * `tabIndex` of 0 even though they cannot be focused, so they are excluded
 * explicitly. Visibility (e.g. `display: none`) is not checked.
 */
function isFocusable(node: Node) {
  // 1 === Node.ELEMENT_NODE
  if (node.nodeType !== 1) {
    return false;
  }
  const element = node as HTMLElement;
  return element.tabIndex >= 0 && !element.matches(':disabled');
}
