import { useMemo } from 'react';
import { Node, Edge } from 'reactflow';
import Dagre from '@dagrejs/dagre';

/** Expand/collapse state carried in a React Flow node's `data`. */
interface NodeData {
  /** When false, every node reachable from this one is removed from the layout. */
  expanded: boolean;
  /** Whether the node has at least one child; recomputed on each layout pass. */
  expandable: boolean;
}

/** Options accepted by `useExpandCollapse`. */
interface UseExpandCollapseOptions {
  /** When false, the hook returns its inputs untouched. Defaults to true. */
  layoutNodes?: boolean;
}

/**
 * Records whether `node` has children in `dagre`. When the node is collapsed,
 * also removes every node reachable from it, so that subtree is not laid out.
 *
 * Side effects: writes `node.data.expandable` and removes nodes from `dagre`
 * in place. Errors are logged and swallowed so one malformed node cannot
 * abort the whole layout pass.
 *
 * NOTE(review): "reachable" includes nodes that also have another, expanded
 * parent; such nodes are still removed when any ancestor is collapsed.
 *
 * @param dagre - Graph holding the nodes/edges to filter (modified in place).
 * @param node - React Flow node whose collapse state is applied.
 */
function filterCollapsedChildren(dagre: Dagre.graphlib.Graph, node: Node<NodeData>) {
  if (!dagre || !node) return;

  try {
    // graphlib returns `undefined` for an id it does not contain (e.g. a node
    // already removed because one of its ancestors is collapsed). The cast works
    // around the library typings; at runtime the result is an array of node ids.
    const children = (dagre.successors(node.id) as unknown as string[] | undefined) ?? [];

    node.data.expandable = children.length > 0;

    if (node.data.expanded) return;

    // Iterative walk over the collapsed subtree. Each visited node is removed
    // immediately, so the `hasNode` check also terminates the walk on cycles.
    // The collapsed node itself is never removed, even if a cycle (or a
    // self-loop) leads back to it.
    const pending = [...children];
    while (pending.length > 0) {
      const child = pending.pop();
      if (!child || child === node.id || !dagre.hasNode(child)) continue;

      const grandChildren = dagre.successors(child) as unknown as string[] | undefined;
      if (grandChildren) {
        pending.push(...grandChildren);
      }
      dagre.removeNode(child);
    }
  } catch (error) {
    console.warn('Error processing node:', node.id, error);
  }
}

/**
 * Size (px) of the slot dagre reserves for a node. The default is 350×350;
 * the height becomes 550 when `data.dataExpanded` is truthy and 100 for `plan`
 * nodes. The `plan` rule takes precedence.
 */
function getLayoutSize(node: Node): { width: number; height: number } {
  let height = 350;
  if (node.data.dataExpanded) {
    height = 550;
  }
  if (node.type === 'plan') {
    height = 100;
  }
  return { width: 350, height };
}

/**
 * Lays out React Flow nodes top-to-bottom with dagre, honouring each node's
 * expand/collapse state.
 *
 * Output nodes:
 * - Nodes reachable from a collapsed node (`data.expanded` false) are omitted
 *   from the returned nodes.
 * - A node whose `data.parentId` is set is returned with `hidden: true` unless
 *   that parent is present and expanded.
 *
 * Output edges: every input edge is returned, with `hidden: true` whenever its
 * source or target is not a visible output node.
 *
 * Side effect: writes `data.expandable` on the input node objects.
 *
 * @param nodes - React Flow nodes to lay out.
 * @param edges - React Flow edges; only `source`/`target` affect the layout.
 * @param options.layoutNodes - When false, `nodes` and `edges` are returned
 *   unchanged. Defaults to true.
 * @returns The positioned nodes and the edges with updated visibility.
 */
function useExpandCollapse(
  nodes: Node[],
  edges: Edge[],
  // showData: boolean,
  { layoutNodes = true }: UseExpandCollapseOptions = {},
): { nodes: Node[]; edges: Edge[] } {
  return useMemo(() => {
    if (!layoutNodes) return { nodes, edges };

    // 1. Build a top-to-bottom dagre graph. Edge labels are not used, so every
    //    edge gets an empty label object.
    const dagre = new Dagre.graphlib.Graph().setDefaultEdgeLabel(() => ({})).setGraph({ rankdir: 'TB' });

    // 2. Register nodes with a fixed slot size (see getLayoutSize) rather than
    //    their rendered size; this is what controls the spacing between nodes.
    for (const node of nodes) {
      dagre.setNode(node.id, { ...getLayoutSize(node), data: node.data });
    }
    for (const edge of edges) {
      dagre.setEdge(edge.source, edge.target);
    }

    // 3. Drop the descendants of collapsed nodes from the graph so they are
    //    neither laid out nor returned. This also refreshes `data.expandable`.
    for (const node of nodes) {
      filterCollapsedChildren(dagre, node);
    }

    // 4. Run the dagre layout.
    Dagre.layout(dagre, { disableOptimalOrderHeuristic: true });

    // First node per id (same result as `nodes.find`), without a scan per lookup.
    const nodesById = new Map<string, Node>();
    for (const node of nodes) {
      if (!nodesById.has(node.id)) nodesById.set(node.id, node);
    }
    // Ids of nodes that are laid out and not hidden; drives edge visibility.
    const visibleIds = new Set<string>();

    // 5. Position every node that survived step 3. `flatMap` doubles as a
    //    filter: returning [] drops the node from the output.
    const layoutedNodes = nodes.flatMap<Node>((node) => {
      if (!dagre.hasNode(node.id)) return [];

      const slot = dagre.node(node.id);
      // dagre reports the centre of each slot, while React Flow positions a node
      // by its top-left corner. Horizontally, centre the node using its measured
      // width when React Flow has one, otherwise the slot width. Vertically,
      // put the node's top edge at the top of its slot.
      const boxWidth = node.width && node.height ? node.width : slot.width;
      const position = {
        x: slot.x - boxWidth / 2,
        y: slot.y - slot.height / 2,
      };

      const { parentId } = node.data;
      const isVisible = !parentId || Boolean(nodesById.get(parentId)?.data.expanded);
      if (isVisible) visibleIds.add(node.id);

      return [
        {
          ...node,
          position,
          data: { ...node.data },
          height: slot.height,
          width: slot.width,
          hidden: !isVisible,
        },
      ];
    });

    return {
      nodes: layoutedNodes,
      // Hide an edge when either endpoint is not visible in the output. This
      // covers endpoints hidden via `parentId`, endpoints removed in step 3, and
      // ids that are not present in `nodes` at all.
      edges: edges.map((edge) => ({
        ...edge,
        hidden: !(visibleIds.has(edge.source) && visibleIds.has(edge.target)),
      })),
    };
  }, [nodes, edges, layoutNodes]);
}

export default useExpandCollapse;
