import dagre from '@dagrejs/dagre';
import { Edge, Node, Position } from 'reactflow';
import { convertRemToPixel } from '../Helpers/convertRemToPixel';
import { nodeWidth } from './constants';
import { nodesWithMultipleEdges } from './nodesWithMultipleEdges';

/**
 * Approximates how tall a node will render, based on how many wrapped lines
 * its title (`node.data.value.name`) needs.
 *
 * Model: a fixed 35-unit base plus 14 units per line, assuming roughly 25
 * characters fit on one line. NOTE(review): units presumably match the
 * rem→px converted width handed to dagre alongside this value — confirm.
 *
 * @param node - Node whose `data.value.name` is measured; a missing or falsy
 *   name is treated as an empty title (yielding just the base height).
 * @returns The estimated height.
 */
const estimateNodeHeight = (node: Node): number => {
  const fixedHeight = 35;
  const perLineHeight = 14;
  const maxCharsPerLine = 25;

  const titleText = node.data?.value?.name || '';
  const wrappedLineCount = Math.ceil(titleText.length / maxCharsPerLine);
  return fixedHeight + wrappedLineCount * perLineHeight;
};

/**
 * Groups nodes under the source of every edge that points at them.
 *
 * @param nodes - All nodes; edge targets are resolved against their `id`.
 *   If several nodes share an id, the first one wins.
 * @param edges - Edges whose `source` ids become keys of the result.
 * @returns A map from each edge source id to the target nodes of its
 *   outgoing edges, in edge order. Every source appears as a key even when
 *   none of its targets could be resolved (its list is then empty). A target
 *   reached by several edges from the same source is listed once per edge.
 */
const groupNodesByParent = (nodes: Node[], edges: Edge[]) => {
  // Index nodes once so each edge resolves its target in O(1) instead of
  // scanning the whole node list per edge. Keep the first node for a given id,
  // matching the previous `Array.prototype.find` lookup.
  const nodesById = new Map<string, Node>();
  for (const node of nodes) {
    if (!nodesById.has(node.id)) {
      nodesById.set(node.id, node);
    }
  }

  // Prototype-less object: ids such as "constructor" or "toString" would
  // otherwise hit inherited Object.prototype members and make `push` throw.
  const parentMap: Record<string, Node[]> = Object.create(null);

  for (const edge of edges) {
    let children = parentMap[edge.source];
    if (children === undefined) {
      children = [];
      parentMap[edge.source] = children;
    }
    const targetNode = nodesById.get(edge.target);
    if (targetNode) {
      children.push(targetNode);
    }
  }

  return parentMap;
};

/**
 * Computes a top-to-bottom dagre layout for a React Flow graph.
 *
 * Each node is given the rem→px converted `nodeWidth` as its width and a
 * height from `estimateNodeHeight`. Each edge gets a dagre label height.
 * Edges leaving a node listed by `nodesWithMultipleEdges` get 48; all other
 * edges get 24.
 *
 * After layout, nodes that are targets of edges sharing one source are
 * top-aligned to the highest of them (smallest `position.y`).
 *
 * NOTE: the input `nodes` are mutated in place (`targetPosition`,
 * `sourcePosition`, `height`, `position`). The same `nodes` and `edges`
 * arrays are returned.
 *
 * @param nodes - React Flow nodes to lay out.
 * @param edges - React Flow edges connecting them.
 * @returns `{ nodes, edges }` — the input arrays, with node positions set.
 */
export const getLayoutedElements = (nodes: Node[], edges: Edge[]) => {
  const dagreGraph = new dagre.graphlib.Graph();
  dagreGraph.setDefaultEdgeLabel(() => ({}));
  dagreGraph.setGraph({ rankdir: 'TB' });

  const width = convertRemToPixel(nodeWidth);
  // NOTE(review): presumably the ids of nodes with more than one outgoing
  // edge — confirm. Kept in a Set for O(1) lookups in the edge loop below.
  const branchNodes = new Set(nodesWithMultipleEdges(nodes, edges));

  nodes.forEach((node) => {
    dagreGraph.setNode(node.id, {
      width,
      height: estimateNodeHeight(node),
    });
  });

  // dagre reserves room for an edge label of this height between ranks, so
  // edges out of branching nodes get a larger gap than ordinary edges.
  const branchEdgeHeight = 48;
  const labelEdgeHeight = 24;

  edges.forEach((edge) => {
    dagreGraph.setEdge(edge.source, edge.target, {
      height: branchNodes.has(edge.source) ? branchEdgeHeight : labelEdgeHeight,
    });
  });

  dagre.layout(dagreGraph);

  const parentMap = groupNodesByParent(nodes, edges);

  nodes.forEach((node) => {
    const laidOut = dagreGraph.node(node.id);
    node.targetPosition = Position.Top;
    node.sourcePosition = Position.Bottom;
    node.height = estimateNodeHeight(node);
    // dagre's x/y are the node centre; shift to the top-left corner stored
    // in `node.position`.
    node.position = {
      x: laidOut.x - laidOut.width / 2,
      y: laidOut.y - laidOut.height / 2,
    };
  });

  // Top-align sibling nodes so every child of the same parent starts on one row.
  Object.values(parentMap).forEach((siblings) => {
    if (siblings.length <= 1) {
      return;
    }
    const minY = Math.min(...siblings.map((sibling) => sibling.position.y));
    siblings.forEach((sibling) => {
      sibling.position.y = minY;
    });
  });

  return { nodes, edges };
};
