import { CSSProperties, useState } from 'react';
import { Table, TableHead, TableBody, TableRow, TableSortLabel } from '@mui/material';
import { makeStyles } from 'tss-react/mui';
import { KeyboardArrowDown } from '@mui/icons-material';
import {
  SortType,
  SortState,
  DashboardTableHeadProps,
  DashboardTableColumnDef,
} from '../../types/dashboard-table-types';
import { countColLeaves } from '../../utils/dashboard-common-utils';
import DashboardTableCell from './DashboardTableCell';

/**
 * Styles for the dashboard table header.
 *
 * - `headerCell`: text styling for individual header cells.
 * - `multiCellContainer`: unpadded, borderless wrapper cell that hosts the nested table used
 *   for grouped (multi-row) column headers.
 * - `borderTop` / `divider`: light-gray separator lines applied conditionally by the header.
 *
 * NOTE(review): `left: 'auto'` presumably resets a `left` offset coming from sticky
 * positioning in the table cell styles — confirm against `DashboardTableCell`.
 */
const useStyles = makeStyles()(() => ({
  headerCell: {
    fontSize: 12,
    fontWeight: 500,
    fontFamily: 'Arial',
    backgroundColor: '#FFFFFF',
    left: 'auto',
  },
  multiCellContainer: {
    backgroundColor: '#FFFFFF',
    padding: 0,
    borderBottom: 'none',
    left: 'auto',
  },
  borderTop: {
    borderTop: '1px solid #F0F0F0',
  },
  divider: {
    borderRight: '1px solid #F0F0F0',
  },
}));

/**
 * Sort icon for a column that currently has an active sort direction.
 * Defined at module level so its component identity is stable across renders; an inline
 * arrow passed as `IconComponent` would be a new component type on every render and force
 * React to remount the icon.
 */
const ActiveSortIcon = (props: { className?: string }) => (
  <KeyboardArrowDown {...props} style={{ color: '#000000' }} />
);

/** Greyed-out sort icon for a sortable column that is not currently sorted. */
const InactiveSortIcon = (props: { className?: string }) => (
  <KeyboardArrowDown {...props} style={{ color: '#999999' }} />
);

/**
 * Advances a column's sort direction through the cycle ASC -> DESC -> unsorted -> ASC.
 *
 * @param direction - The column's current direction (`undefined` means unsorted).
 * @returns The next direction in the cycle.
 */
const toggleDirection = (direction: SortType | undefined): SortType | undefined => {
  if (direction === SortType.ASC) {
    return SortType.DESC;
  }
  if (direction === SortType.DESC) {
    return undefined;
  }
  return SortType.ASC;
};

/**
 * Header row for the dashboard table.
 *
 * Renders one header cell per column in `columnDef` that is not `hidden`. A column with
 * `columnChildren` is rendered as a nested table: a spanning title cell above a row of its
 * child headers (recursively). Clicking a sortable leaf column cycles its direction
 * (ASC -> DESC -> unsorted) and reports the new full sort state through `column.onSort`.
 *
 * Sort state is held locally in this component, keyed by `column.keyIndex`.
 */
const DashboardTableHead = <T extends Record<string, unknown>>({ columnDef }: DashboardTableHeadProps<T>) => {
  // `cx` joins class names and drops falsy entries, so conditional classes never leak a
  // literal "undefined" into the rendered class attribute.
  const { classes, cx } = useStyles();
  const [sortState, setSortState] = useState<SortState>({});

  /** Advances the sort direction of a sortable column and notifies its `onSort` callback. */
  const handleSort = (column: DashboardTableColumnDef<T>) => {
    if (!column.sortable) {
      return;
    }
    const newSortState: SortState = {
      ...sortState,
      [column.keyIndex]: toggleDirection(sortState[column.keyIndex]),
    };
    setSortState(newSortState);
    column.onSort?.(column.keyIndex, newSortState);
  };

  /** Renders a single (leaf) header cell, with a sort indicator when the column is sortable. */
  const renderHeaderCell = (column: DashboardTableColumnDef<T>, className: string) => {
    const style: CSSProperties = column.sortable
      ? { ...column.headerStyle, cursor: 'pointer', whiteSpace: 'nowrap' }
      : { ...column.headerStyle };
    const direction = sortState[column.keyIndex];
    const isSorted = direction !== undefined;

    return (
      <DashboardTableCell
        key={`table-head-${column.keyIndex}`}
        align={column.align}
        className={className}
        style={style}
        onClick={() => handleSort(column)}
      >
        <span style={column.headerTextStyle}>{column.displayName}</span>
        {column.sortable && (
          // Unsorted columns still show a greyed "desc" arrow as a hint that they are sortable.
          <TableSortLabel
            active={true}
            direction={isSorted ? direction : 'desc'}
            IconComponent={isSorted ? ActiveSortIcon : InactiveSortIcon}
          />
        )}
      </DashboardTableCell>
    );
  };

  /**
   * Renders a grouped column as a nested table (title row + row of children), recursing into
   * child groups. Leaf columns fall through to `renderHeaderCell`.
   *
   * @param isLast - Whether this column is the last among its siblings; the last one gets no
   *   right-hand divider.
   */
  const renderMultiRowHeader = (column: DashboardTableColumnDef<T>, isLast?: boolean) => {
    const { columnChildren } = column;
    if (!columnChildren) {
      return renderHeaderCell(column, cx(classes.headerCell, !isLast && classes.divider));
    }

    const colSpan = countColLeaves(column);
    const lastChildIndex = columnChildren.length - 1;
    return (
      <DashboardTableCell
        key={`table-head-${column.keyIndex}`}
        colSpan={colSpan}
        align={'center'}
        className={cx(classes.multiCellContainer, classes.borderTop, !isLast && classes.divider)}
        style={column.headerStyle}
      >
        <Table>
          <TableBody>
            <TableRow>
              <DashboardTableCell colSpan={colSpan} align={'center'} className={classes.headerCell}>
                <span style={column.headerTextStyle}>{column.displayName}</span>
              </DashboardTableCell>
            </TableRow>
            <TableRow>
              {columnChildren.map((child, index) => renderMultiRowHeader(child, index === lastChildIndex))}
            </TableRow>
          </TableBody>
        </Table>
      </DashboardTableCell>
    );
  };

  // "Last" must be measured against the visible columns only; comparing against the full
  // `columnDef` length misses the last visible group whenever any column is hidden.
  const visibleColumns = columnDef.filter((column) => !column.hidden);
  const lastVisibleIndex = visibleColumns.length - 1;

  return (
    <TableRow>
      {visibleColumns.map((column, index) =>
        column.columnChildren
          ? renderMultiRowHeader(column, index === lastVisibleIndex)
          : renderHeaderCell(
              column,
              cx(classes.headerCell, classes.borderTop, column.hasDivider && classes.divider),
            ),
      )}
    </TableRow>
  );
};

export default DashboardTableHead;
