import React, { useMemo } from 'react';
import { Table, TableBody, TableCell, TableContainer, TableRow, Paper } from '@mui/material';
import { styled } from '@mui/material/styles';

/**
 * MUI TableCell with theme-divider borders on its bottom and right edges and
 * one theme spacing unit of padding. The right border is removed on the last
 * cell of each row so the rightmost column has no trailing vertical rule.
 */
const StyledTableCell = styled(TableCell)(({ theme }) => {
  const dividerLine = `1px solid ${theme.palette.divider}`;
  return {
    borderBottom: dividerLine,
    borderRight: dividerLine,
    padding: theme.spacing(1),
    '&:last-child': { borderRight: 'none' },
  };
});

/**
 * Lays out a (possibly nested) column schema as a grid of header cells that can
 * be rendered as table rows with `colSpan`s.
 *
 * Each node is an object with a `name` and an optional `children` array.
 * - Leaf nodes are always placed on the bottom row, whatever their depth.
 * - A parent node is placed on the row matching its nesting depth and spans the
 *   number of leaf columns beneath it.
 * - Parents named 'None' get no header cell (their slot stays ''), but their
 *   children are still laid out.
 * - For nested schemas, leading rows that end up with no labels are dropped.
 *
 * @param {Array<{name: string, children?: Array}>|null|undefined} schema
 *   Top-level schema nodes. Anything that is not a non-empty array yields an
 *   empty layout.
 * @returns {{matrix: Array<Array<*>>, spans: number[][], maxDepth: number}}
 *   `matrix[row][col]` is the cell label ('' for an empty slot);
 *   `spans[row][col]` is that cell's colSpan, or 0 when the column is covered
 *   by a span starting further left; `maxDepth` is the number of rows returned.
 */
const processSchema = (schema) => {
  // Guard against null/undefined AND non-array values (e.g. `{}`), which
  // would otherwise crash on `.every` below.
  if (!Array.isArray(schema) || schema.length === 0) {
    return { matrix: [], spans: [], maxDepth: 0 };
  }

  const isLeaf = (node) => !node.children || node.children.length === 0;

  // Fast path: a single-level schema is one row of unit-width cells.
  if (schema.every(isLeaf)) {
    return {
      matrix: [schema.map((node) => node.name)],
      spans: [schema.map(() => 1)],
      maxDepth: 1,
    };
  }

  // Levels from this node down to its deepest leaf (a leaf counts as 1).
  const getDepth = (node) => (
    isLeaf(node) ? 1 : 1 + Math.max(...node.children.map((child) => getDepth(child)))
  );
  // Leaf columns under this node (a leaf occupies exactly one column).
  const getWidth = (node) => (
    isLeaf(node) ? 1 : node.children.reduce((sum, child) => sum + getWidth(child), 0)
  );

  const maxDepth = Math.max(...schema.map((node) => getDepth(node)));
  const totalWidth = schema.reduce((sum, node) => sum + getWidth(node), 0);
  const leafRow = maxDepth - 1;

  const matrix = Array.from({ length: maxDepth }, () => Array(totalWidth).fill(''));
  // Every cell starts with span 1; leaves never change that.
  const spans = Array.from({ length: maxDepth }, () => Array(totalWidth).fill(1));

  // Places `node` (at nesting level `depth`) starting at column `startCol` and
  // returns the first column after it. Children are placed first so the
  // parent's span is simply the number of columns they consumed, which avoids
  // re-walking the subtree. Parents and descendants write to different rows,
  // so this order does not affect the result.
  const placeNode = (node, startCol, depth) => {
    if (isLeaf(node)) {
      matrix[leafRow][startCol] = node.name;
      return startCol + 1;
    }

    let endCol = startCol;
    for (const child of node.children) {
      endCol = placeNode(child, endCol, depth + 1);
    }

    // 'None' parents are placeholders: lay out the children, draw no header.
    if (node.name !== 'None') {
      matrix[depth][startCol] = node.name;
      spans[depth][startCol] = endCol - startCol;
      // Columns covered by this span must not render a cell of their own.
      for (let col = startCol + 1; col < endCol; col += 1) {
        spans[depth][col] = 0;
      }
    }
    return endCol;
  };

  let nextCol = 0;
  for (const node of schema) {
    nextCol = placeNode(node, nextCol, 0);
  }

  // Drop leading rows with no labels at all (e.g. when every top-level parent
  // is 'None'), keeping `spans` aligned with `matrix`.
  while (matrix.length > 0 && matrix[0].every((cell) => cell === '')) {
    matrix.shift();
    spans.shift();
  }

  return { matrix, spans, maxDepth: matrix.length };
};

const SchemaPreview = ({ schema }) => {
  const { matrix, spans, maxDepth } = useMemo(() => processSchema(schema), [schema]);

  if (!schema || schema.length === 0) {
    return null;
  }

  return (
    <TableContainer component={Paper} sx={{ mt: 2, mb: 2 }}>
      <Table size="small">
        <TableBody>
          {matrix.map((row, rowIndex) => (
            <TableRow key={rowIndex}>
              {row.map((cell, colIndex) => {
                if (spans[rowIndex][colIndex] === 0) return null;
                return (
                  <StyledTableCell
                    key={colIndex}
                    colSpan={spans[rowIndex][colIndex]}
                    align="center"
                  >
                    {cell}
                  </StyledTableCell>
                );
              })}
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </TableContainer>
  );
};

export default SchemaPreview;