import "reactflow/dist/style.css";

import { css, cva } from "@styled-system/css";
import React, { memo, useCallback, useMemo } from "react";
import ReactFlow, {
  addEdge,
  Background,
  BackgroundVariant,
  Connection,
  Controls,
  Edge,
  Handle,
  MarkerType,
  Position,
  useEdgesState,
  useNodesState,
} from "reactflow";

/**
 * Class-name generator (built with `cva`) for the outermost wrapper of a table
 * node: a white card with a 1px grey border, rounded corners, a soft drop
 * shadow, and 12px Arial text. Only `base` styles are declared, so it is
 * called without arguments.
 */
const nodeContainer = cva({
  base: {
    border: "1px solid #777",
    borderRadius: "5px",
    backgroundColor: "#fff",
    boxShadow: "0 4px 6px rgba(0, 0, 0, 0.1)",
    fontSize: "12px",
    fontFamily: "Arial, sans-serif",
  },
});

/**
 * Class-name generator for the title bar of a table node (the element that
 * shows the table label): blue background with white, bold, centered text.
 */
const nodeHeader = cva({
  base: {
    backgroundColor: "#4a90e2",
    color: "white",
    padding: "8px 10px",
    // Only the top corners are rounded; 4px is presumably chosen so the header
    // sits inside the container's 5px radius and 1px border — confirm visually.
    borderTopLeftRadius: "4px",
    borderTopRightRadius: "4px",
    fontWeight: "bold",
    textAlign: "center",
  },
});

/**
 * Class-name generator for the section of a table node that holds the column
 * rows; it only adds 10px of padding around them.
 */
const nodeBody = cva({
  base: {
    padding: "10px",
  },
});

/**
 * Class-name generator for a single column row inside a table node. The row is
 * a flex container with `space-between`, so the column name and the optional
 * "PK" marker are pushed to opposite ends; a light bottom border separates
 * consecutive rows.
 */
const nodeColumn = cva({
  base: {
    display: "flex",
    justifyContent: "space-between",
    padding: "4px 0",
    borderBottom: "1px solid #eee",
  },
});

/**
 * Class-name generator for the "PK" marker rendered next to a column name:
 * muted grey text with a left margin that keeps it clear of the name.
 */
const pkIndicator = cva({
  base: {
    color: "#888",
    marginLeft: "10px",
  },
});

/**
 * Class-name generator applied to the React Flow `Handle` elements of a table
 * node; it gives the connection points a dark grey background.
 */
const handle = cva({
  base: {
    background: "#555",
  },
});

/** Payload carried by a table node: what the node displays. */
interface TableNodeData {
  /** Text shown in the node header (the table name). */
  label: string;
  /** Column names listed in the node body, in display order. */
  columns: Array<string>;
}

/** Props received by the `TableNode` custom node component. */
interface TableNodeProps {
  /** The node's display payload. */
  data: TableNodeData;
}

/**
 * Custom React Flow node that draws one database table as a card: a header
 * with the table label followed by one row per column. The first column in the
 * list is tagged with a "PK" marker. A source handle is placed on the right
 * edge and a target handle on the left edge.
 *
 * Wrapped in `memo`, so it re-renders only when its props change.
 */
const TableNode: React.FC<TableNodeProps> = memo(({ data }) => {
  const { label, columns: columnNames } = data;

  // Build the rows up front so the returned markup stays flat and readable.
  const columnRows = columnNames.map((columnName, position) => {
    // Only the first listed column receives the "PK" marker.
    const isFirstColumn = position === 0;

    return (
      <div className={nodeColumn()} key={columnName}>
        <span>{columnName}</span>
        {isFirstColumn ? <span className={pkIndicator()}>PK</span> : null}
      </div>
    );
  });

  return (
    <div className={nodeContainer()} data-testid="table-node">
      <div className={nodeHeader()}>{label}</div>
      <div className={nodeBody()}>{columnRows}</div>
      <Handle className={handle()} position={Position.Right} type="source" />
      <Handle className={handle()} position={Position.Left} type="target" />
    </div>
  );
});

// `memo` around an anonymous arrow function yields no useful component name,
// so set one explicitly for dev tools and warnings.
TableNode.displayName = "TableNode";

/**
 * Registry of custom node renderers handed to `<ReactFlow nodeTypes=…>`. The
 * `tableNode` key is the value nodes use in their `type` field. It lives at
 * module scope so the same object is passed on every render.
 */
const nodeTypes = { tableNode: TableNode };

// SQLDiagram component props

/** One join between two tables; each join is drawn as a single edge. */
interface TableJoin {
  /** Table the edge starts from. */
  sourceTable: string;
  /** Column on the source table that takes part in the join. */
  sourceColumn: string;
  /** Table the edge points to. */
  targetTable: string;
  /** Column on the target table that takes part in the join. */
  targetColumn: string;
}

/** Props accepted by the `SQLDiagram` component. */
interface SQLDiagramProps {
  /** Table names; one node is created per entry, in this order. */
  tables: Array<string>;
  /** Column names keyed by table name. */
  columns: Record<string, Array<string>>;
  /** Joins between tables; one edge is created per entry. */
  joins: Array<TableJoin>;
}

/**
 * Renders tables and the joins between them as a React Flow diagram.
 *
 * Each entry of `tables` becomes a `tableNode` node laid out on a diagonal
 * (300px right and 100px down per index); each entry of `joins` becomes an
 * animated smooth-step edge with a closed arrow head. Connecting and selecting
 * are disabled on the canvas.
 *
 * The diagram follows its props: whenever `tables`, `columns` or `joins`
 * change identity, nodes and edges are rebuilt from them. Rebuilding resets
 * node positions, so callers that want dragged positions to survive parent
 * re-renders should pass referentially stable (memoised) props.
 *
 * @param tables  Table names, one node each.
 * @param columns Column names keyed by table name; a table with no entry is
 *                rendered with an empty column list.
 * @param joins   Joins between tables, one edge each.
 */
const SQLDiagram: React.FC<SQLDiagramProps> = ({ tables, columns, joins }) => {
  const initialNodes = useMemo(
    () =>
      tables.map((table, index) => ({
        id: table,
        type: "tableNode",
        data: {
          label: table,
          // A table without an entry in `columns` would otherwise hand
          // `undefined` to TableNode, which calls `.map` on it and throws.
          columns: columns[table] ?? [],
        },
        position: { x: index * 300, y: index * 100 },
      })),
    [tables, columns]
  );

  const initialEdges = useMemo(
    () =>
      joins.map((join, index) => ({
        id: `e${index}`,
        source: join.sourceTable,
        target: join.targetTable,
        // NOTE(review): the label lists the target column first although the
        // edge runs source → target — confirm this ordering is intended.
        label: `${join.targetColumn} → ${join.sourceColumn}`,
        type: "smoothstep",
        animated: true,
        style: { stroke: "#4a90e2" },
        markerEnd: { type: MarkerType.ArrowClosed, color: "#4a90e2" },
      })),
    [joins]
  );

  const [nodes, setNodes, onNodesChange] = useNodesState(initialNodes);
  const [edges, setEdges, onEdgesChange] = useEdgesState(initialEdges);

  // The state hooks read their argument only as the initial value, so later
  // prop changes have to be pushed into state explicitly. On mount this sets
  // the same array the hook was seeded with.
  React.useEffect(() => {
    setNodes(initialNodes);
  }, [initialNodes, setNodes]);

  React.useEffect(() => {
    setEdges(initialEdges);
  }, [initialEdges, setEdges]);

  // Kept wired up even though `nodesConnectable` is currently false, so that
  // enabling connections later only requires flipping that prop.
  const onConnect = useCallback(
    (params: Edge | Connection) => setEdges((eds) => addEdge(params, eds)),
    [setEdges]
  );

  return (
    <div className={css({ width: "100%", height: "100%" })}>
      <ReactFlow
        edges={edges}
        elementsSelectable={false}
        fitView
        id="sql-diagram"
        nodeTypes={nodeTypes}
        nodes={nodes}
        nodesConnectable={false}
        onConnect={onConnect}
        onEdgesChange={onEdgesChange}
        onNodesChange={onNodesChange}
        tabIndex={0}
      >
        <Controls />
        <Background gap={12} size={1} variant={BackgroundVariant.Dots} />
      </ReactFlow>
    </div>
  );
};

export default SQLDiagram;
