// TODO: finish the footer
import {
    Checkbox,
    IconButton,
    Paper,
    Stack,
    Table,
    TableBody,
    TableCell,
    TableContainer,
    TableFooter,
    TableHead,
    TableOwnProps,
    TablePagination,
    TableRow,
    TableSortLabel,
    styled,
} from "@mui/material"
import classnames from "classnames"
import React, { ReactNode, useCallback, useEffect, useMemo, useState } from "react"
import { ChevronDown, ChevronUp } from "react-feather"
import { useIntl } from "react-intl"

import { commonMessages } from "~/common-messages"
import { PAGE_SIZE } from "~/constants/pagination"
import { selectDataGridCommonState } from "~/store/global/globalSlice"
import { useAppSelector } from "~/store/hooks"
import { getComparator, stopPropagation, tableSort } from "~/utils"

import "./DataTable.scss"

/** Reads, from a row, the plain value a column displays and sorts by. */
type GetValueFunc<T> = (row: T) => number | string | null
/** Produces custom cell content for a row; `index` is the row's position in the list being rendered. */
type RenderCellFunc<T> = (row: T, index: number) => React.ReactNode

/** Column fields shared by every column, regardless of how its value is read. */
type DataTableColumnBase<T> = {
    /** Stable identifier: React key, sort id, and matched against "name" for the tree toggle. */
    id: string
    /** Header content. */
    label: React.ReactNode
    /** Optional secondary line rendered under the header label. */
    subtitle?: React.ReactNode
    /** Enables the header sort control for this column. */
    sorter?: boolean
    /** When explicitly `false`, the column is not rendered. */
    visible?: boolean
    /** Cell style; when it sets `maxWidth`, the cell is also truncated to one line with an ellipsis. */
    style?: React.CSSProperties
    /** Custom cell renderer; takes precedence over `getValue` and `key`. */
    renderCell?: RenderCellFunc<T>
    /** Builds this column's footer cell from all rows when the aggregated footer is enabled. */
    aggregationFunction?: (items: T[]) => React.ReactNode
}

/** How a column reads its value from a row: through a property key or a getter, never both. */
type DataTableColumnAccessor<T> =
    | {
          /** Row property to display and sort by. */
          key: keyof T | string
          getValue?: undefined
      }
    | {
          key?: undefined
          /** Computes the value to display and sort by. */
          getValue: GetValueFunc<T>
      }

export type DataTableColumn<T> = DataTableColumnBase<T> & DataTableColumnAccessor<T>

/** Props shared by every DataTable configuration. */
type DataTableBaseProps<T> = {
    /** Rows to display; `undefined` renders an empty body. */
    data: T[] | undefined
    /** Column definitions; columns with `visible: false` are skipped. */
    columns: DataTableColumn<T>[]
    /** Called with the row when a top-level row is clicked; also enables row hover and the pointer cursor. */
    handleClickRow?: (item: T) => void
    /** Extra class names for the table container. */
    classNames?: string
    /** Column id to sort by initially. */
    defaultOrderBy?: string
    /** Initial sort direction; descending when omitted. */
    defaultOrderDirection?: OrderDirection
    /** Notified with the column id and the new direction whenever the user changes the sort. */
    onRequestSort?: (orderBy: string, orderDirection: OrderDirection) => void
    /** Hides the pagination control. */
    hidePagination?: boolean
    /** Renders a footer row built from each column's `aggregationFunction`. */
    withAggregatedFooter?: boolean
    /** Row property whose value identifies a row for selection and tree expansion. */
    keySelect?: keyof T | null
}

/** Row selection is either off, or on with an initial selection and a change callback. */
type DataTableSelectionProps =
    | {
          selectable: true
          /** Initially selected row keys (values of `row[keySelect]`). */
          defaultSelected: unknown[]
          /** Receives the complete selection after every change. */
          onSelect: (selected: unknown[]) => void
      }
    | {
          selectable?: false
          defaultSelected?: never
          onSelect?: never
      }

/** Tree display is either off, or on with a way to read each row's nested rows. */
type DataTableTreeProps<T> =
    | {
          treeDisplay: true
          /** Returns the nested rows of a row. */
          getChildren: (item: T) => T[]
          /** When true, every row at every depth is expanded. */
          isFiltered?: boolean
      }
    | {
          treeDisplay?: false | undefined
          getChildren?: never
          isFiltered?: never
      }

type Props<T> = DataTableBaseProps<T> & DataTableSelectionProps & DataTableTreeProps<T>

/**
 * Sort direction. The string values are passed straight to MUI `TableSortLabel`'s `direction`
 * prop and to `onRequestSort`.
 */
export enum OrderDirection {
    /** Ascending. */
    ASC = "asc",
    /** Descending; the table's initial direction when `defaultOrderDirection` is not given. */
    DESC = "desc",
}

/**
 * Resolves what the sort comparator should read for the column with the given id:
 * the column's `getValue` getter when it has one, otherwise its `key`.
 * Returns `undefined` when no column has that id.
 */
const getComparatorKey = <T,>(
    columns: DataTableColumn<T>[],
    id: string | undefined
): keyof T | GetValueFunc<T> | undefined => {
    const match = columns.find((candidate) => candidate.id === id)
    if (match === undefined) {
        return undefined
    }
    return match.getValue ? match.getValue : (match.key as keyof T)
}

/** A column is shown unless it explicitly sets `visible: false`. */
const isColumnVisible = <T,>({ visible }: DataTableColumn<T>): boolean => visible !== false

/** Secondary header line rendered beneath a column label, in the small font size. */
const TableLabelSubtitle = styled("span")({
    color: "var(--color-line-workflow)",
    fontSize: "var(--font-size-sm)",
    height: "var(--line-height-sm)",
    lineHeight: "var(--line-height-sm)",
})

/** Placeholder rendered in place of the expand/collapse button when a tree row has no children. */
const Spacer = styled("div")({
    height: "var(--spacing-md)",
    width: "var(--spacing-md)",
})

/** Identifier of the column that receives tree-mode treatment (expand/collapse toggle and nested-row indentation). */
const TABLE_NAME_KEY = "name"

/**
 * True when `item` is an object carrying a `parentId` field.
 * Non-object values (which rows should never be) simply do not match.
 */
const hasParentIdField = (item: unknown): boolean =>
    typeof item === "object" && item !== null && "parentId" in item

/**
 * Sortable, paginated MUI table with optional row selection, tree-style nested rows and an
 * aggregated footer.
 *
 * - Columns whose `visible` is `false` are not rendered.
 * - Rows are sorted client-side; `onRequestSort` is notified of every sort change.
 * - Selection (when `selectable` and `keySelect` are set) stores `row[keySelect]` values and
 *   reports the complete selection through `onSelect` after every change.
 * - In tree mode the column whose id is "name" gets an expand/collapse toggle; nested rows come
 *   from `getChildren`, and `isFiltered` expands every row.
 * - `hidePagination` hides the pagination control and renders every row.
 * - Density follows the global data-grid setting ("compact" maps to MUI's "small" size).
 */
export const DataTable = <T,>({
    data,
    columns,
    selectable,
    defaultSelected,
    onSelect,
    keySelect,
    handleClickRow,
    classNames,
    defaultOrderBy,
    defaultOrderDirection,
    onRequestSort,
    hidePagination,
    withAggregatedFooter,
    treeDisplay,
    getChildren,
    isFiltered,
}: Props<T>) => {
    const [page, setPage] = useState(0)
    const [rowsPerPage, setRowsPerPage] = useState(PAGE_SIZE)
    const [valueToOrderBy, setValueToOrderBy] = useState(defaultOrderBy)
    const [orderDirection, setOrderDirection] = useState(defaultOrderDirection ?? OrderDirection.DESC)
    const [selected, setSelected] = useState<unknown[]>(defaultSelected ?? [])
    // Keys (values of row[keySelect]) of the rows whose children are currently shown.
    const [expanded, setExpanded] = useState<Set<string>>(new Set())
    const { formatMessage } = useIntl()

    const gridCommonState = useAppSelector(selectDataGridCommonState)

    const tableSize: TableOwnProps["size"] = gridCommonState?.density === "compact" ? "small" : "medium"

    const classes = classnames(classNames, {
        "cursor-pointer": !!handleClickRow,
    })

    const countData = data?.length ?? 0

    // Derived rather than reassigning the `columns` prop, so the memoized `updatedColumns`
    // below is only rebuilt when the `columns` reference (or the tree state) changes.
    const visibleColumns = useMemo(() => columns.filter(isColumnVisible), [columns])

    // Header, body and footer all use this one condition so the selection column lines up.
    const showSelectionColumn = !!selectable && !!keySelect

    useEffect(() => {
        if (defaultSelected) {
            setSelected(defaultSelected)
        }
    }, [defaultSelected])

    // When the data shrinks (e.g. after filtering) the current page can fall past the end;
    // snap back to the last existing page instead of rendering an empty body.
    useEffect(() => {
        const lastPage = Math.max(0, Math.ceil(countData / rowsPerPage) - 1)
        if (page > lastPage) {
            setPage(lastPage)
        }
    }, [countData, rowsPerPage, page])

    const handleRequestSort = (property: string) => () => {
        // Re-clicking the column already sorted ascending flips it to descending; any other click sorts ascending.
        const nextDirection =
            valueToOrderBy === property && orderDirection === OrderDirection.ASC
                ? OrderDirection.DESC
                : OrderDirection.ASC
        setValueToOrderBy(property)
        setOrderDirection(nextDirection)
        onRequestSort?.(property, nextDirection)
    }

    const handleChangePage = (_event: unknown, newPage: number) => {
        setPage(newPage)
    }

    const handleChangeRowsPerPage = (event: React.ChangeEvent<HTMLInputElement>) => {
        setRowsPerPage(parseInt(event.target.value, 10))
        setPage(0)
    }

    // Clears the selection when everything is selected; otherwise selects every top-level row.
    const onSelectAllClick = useCallback(() => {
        if (selected.length === countData) {
            setSelected([])
            onSelect?.([])
        } else if (selectable && keySelect) {
            const newSelected = data ? data.map((d) => d[keySelect]) : []
            setSelected(newSelected)
            onSelect?.(newSelected)
        }
    }, [selected, data, keySelect, onSelect, selectable, countData])

    // Returns a handler that toggles one row key in the selection.
    const checkSelected = useCallback(
        (entityId: unknown) => () => {
            const index = selected.indexOf(entityId)
            const newSelected =
                index >= 0 ? [...selected.slice(0, index), ...selected.slice(index + 1)] : [...selected, entityId]
            setSelected(newSelected)
            onSelect?.(newSelected)
        },
        [selected, onSelect]
    )

    const getCellStyle = (column: DataTableColumn<T>) => {
        if (!column.style) {
            return undefined
        }

        const style: React.CSSProperties = {}
        if (column.style.maxWidth) {
            // A bounded width means long content is cut to one line with an ellipsis.
            style.maxWidth = column.style.maxWidth
            style.textOverflow = "ellipsis"
            style.whiteSpace = "nowrap"
            style.overflow = "hidden"
        }

        return { ...column.style, ...style }
    }

    const renderHeaderLabel = (label: ReactNode, subtitle: ReactNode): ReactNode => (
        <Stack>
            <span>{label}</span>
            {!!subtitle && <TableLabelSubtitle>{subtitle}</TableLabelSubtitle>}
        </Stack>
    )

    // Cell content priority: renderCell, then getValue, then the stringified `key` property.
    const renderCellContent = (column: DataTableColumn<T>, dataLine: T, index: number) => {
        if (column.renderCell) {
            return column.renderCell(dataLine, index)
        }
        if (column.getValue) {
            return column.getValue(dataLine)
        }
        if (typeof column.key === "string" || column.key) {
            return `${dataLine[column.key as keyof T]}`
        }
        return null
    }

    // Selection checkbox cell for one row, or nothing when rows cannot be selected.
    const renderSelectionCell = (item: T) => {
        if (!selectable || !keySelect) {
            return null
        }
        const itemKey = item[keySelect]
        return (
            <TableCell align="left" className="select-multi">
                <Checkbox
                    checked={selected.indexOf(itemKey) >= 0}
                    onClick={stopPropagation}
                    onChange={checkSelected(itemKey)}
                />
            </TableCell>
        )
    }

    // Collects row[keySelect] for every row at every depth reachable through getChildren.
    const collectAllKeys = useCallback(
        (items: T[]): Set<string> => {
            const allKeys = new Set<string>()
            const visit = (nodes: T[]): void => {
                nodes.forEach((node) => {
                    allKeys.add(node[keySelect as keyof T] as string)
                    const nested = getChildren?.(node) || []
                    if (nested.length) {
                        visit(nested)
                    }
                })
            }
            visit(items)
            return allKeys
        },
        [getChildren, keySelect]
    )

    useEffect(() => {
        if (isFiltered && data && keySelect && getChildren) {
            // A filtered tree is shown fully expanded so every match is visible.
            setExpanded(collectAllKeys(data))
        }
    }, [isFiltered, data, keySelect, getChildren, collectAllKeys])

    const handleToggleExpand = useCallback(
        (keySelectValue: string) => (event: React.MouseEvent) => {
            // The toggle sits inside a clickable row: without this the click would also fire handleClickRow.
            event.stopPropagation()
            setExpanded((previous) => {
                const next = new Set(previous)
                if (next.has(keySelectValue)) {
                    next.delete(keySelectValue)
                } else {
                    next.add(keySelectValue)
                }
                return next
            })
        },
        []
    )

    // In tree mode, the "name" column gets an expand/collapse toggle in front of its content.
    const updatedColumns = useMemo(() => {
        return visibleColumns.map((column) => {
            if (treeDisplay && column.id === TABLE_NAME_KEY) {
                return {
                    ...column,
                    renderCell: (dataLine: T, index: number) => {
                        const rowKey = dataLine[keySelect as keyof T] as string
                        return (
                            <div className="vertical-center">
                                {getChildren && getChildren(dataLine)?.length ? (
                                    <IconButton onClick={handleToggleExpand(rowKey)}>
                                        {expanded.has(rowKey) ? <ChevronUp /> : <ChevronDown />}
                                    </IconButton>
                                ) : (
                                    <Spacer />
                                )}
                                {column.renderCell ? (
                                    column.renderCell(dataLine, index)
                                ) : (
                                    <span>
                                        {column.getValue
                                            ? column.getValue(dataLine)
                                            : (dataLine[column.key as keyof T] as React.ReactNode)}
                                    </span>
                                )}
                            </div>
                        )
                    },
                }
            }
            return column
        })
    }, [visibleColumns, treeDisplay, expanded, keySelect, handleToggleExpand, getChildren])

    const globalCheckboxIndeterminate = useMemo(
        () => !!selected.length && selected.length < countData,
        [selected, countData]
    )
    const globalCheckboxChecked = useMemo(
        () => !!selected.length && selected.length === countData,
        [selected, countData]
    )

    // Renders the children of an expanded row (and, recursively, of expanded children),
    // indenting the "name" column by 40px per nesting level.
    const renderChildren = (children: T[], level = 1): React.ReactNode[] =>
        children.map((child, index) => {
            // NOTE(review): first-level children that carry a `parentId` field are skipped; the
            // reason is not visible in this file — confirm against the data model.
            if (level === 1 && hasParentIdField(child)) {
                return null
            }
            // Same source as the toggle button, so a visible toggle always has rows to reveal.
            const grandChildren = getChildren?.(child) ?? []
            const isExpanded = expanded.has(child[keySelect as keyof T] as string)

            return (
                <React.Fragment key={index}>
                    <TableRow>
                        {renderSelectionCell(child)}
                        {updatedColumns.map((column) => (
                            <TableCell
                                key={column.id}
                                style={{
                                    paddingLeft: column.id === TABLE_NAME_KEY ? `${level * 40}px` : "16px",
                                }}
                            >
                                {renderCellContent(column, child, index)}
                            </TableCell>
                        ))}
                    </TableRow>
                    {isExpanded && grandChildren.length > 0 ? renderChildren(grandChildren, level + 1) : null}
                </React.Fragment>
            )
        })

    const sortedData = tableSort(
        data ?? [],
        getComparator(orderDirection, getComparatorKey(updatedColumns, valueToOrderBy))
    )
    // With the pagination control hidden there is no way to change page, so render every row.
    const visibleRows = hidePagination
        ? sortedData
        : sortedData.slice(page * rowsPerPage, page * rowsPerPage + rowsPerPage)

    return (
        <TableContainer component={Paper} className={classes}>
            <Table stickyHeader size={tableSize}>
                <TableHead>
                    <TableRow>
                        {showSelectionColumn ? (
                            <TableCell align="left">
                                <Checkbox
                                    indeterminate={globalCheckboxIndeterminate}
                                    checked={globalCheckboxChecked}
                                    onChange={onSelectAllClick}
                                />
                            </TableCell>
                        ) : null}
                        {updatedColumns.map(({ id, sorter, label, subtitle }) => (
                            <TableCell align="left" key={id}>
                                {sorter ? (
                                    <Stack>
                                        <TableSortLabel
                                            style={{ flexDirection: "row" }}
                                            active={valueToOrderBy === id}
                                            direction={orderDirection}
                                            onClick={handleRequestSort(id)}
                                        >
                                            <span>{label}</span>
                                        </TableSortLabel>
                                        {!!subtitle && <TableLabelSubtitle>{subtitle}</TableLabelSubtitle>}
                                    </Stack>
                                ) : (
                                    renderHeaderLabel(label, subtitle)
                                )}
                            </TableCell>
                        ))}
                    </TableRow>
                </TableHead>
                <TableBody>
                    {visibleRows.map((dataLine, i) => (
                        <React.Fragment key={i}>
                            <TableRow onClick={() => handleClickRow?.(dataLine)} hover={!!handleClickRow}>
                                {renderSelectionCell(dataLine)}
                                {updatedColumns.map((column) => (
                                    <TableCell align="left" key={column.id} sx={getCellStyle(column)}>
                                        {renderCellContent(column, dataLine, i)}
                                    </TableCell>
                                ))}
                            </TableRow>
                            {treeDisplay && getChildren && expanded.has(dataLine[keySelect as keyof T] as string)
                                ? renderChildren(getChildren(dataLine) ?? [])
                                : null}
                        </React.Fragment>
                    ))}
                </TableBody>
                {data && withAggregatedFooter ? (
                    <TableFooter>
                        <TableRow>
                            {/* Empty cell under the selection column keeps aggregates aligned with their columns. */}
                            {showSelectionColumn ? <TableCell /> : null}
                            {updatedColumns.map((column) => (
                                <TableCell align="left" key={`footer-${column.id}`} sx={getCellStyle(column)}>
                                    {column.aggregationFunction ? column.aggregationFunction(data) : null}
                                </TableCell>
                            ))}
                        </TableRow>
                    </TableFooter>
                ) : null}
            </Table>
            {!hidePagination ? (
                <TablePagination
                    labelRowsPerPage={formatMessage(commonMessages.rowsPerPage)}
                    labelDisplayedRows={(pagination) => {
                        if (pagination.count === -1) {
                            return formatMessage(commonMessages.labelDisplayedRowsNoCount, {
                                from: pagination.from,
                                to: pagination.to,
                            })
                        }
                        return formatMessage(commonMessages.labelDisplayedRows, {
                            from: pagination.from,
                            to: pagination.to,
                            count: pagination.count,
                        })
                    }}
                    rowsPerPageOptions={[10, 25, 50, 100]}
                    component="div"
                    count={countData}
                    rowsPerPage={rowsPerPage}
                    page={page}
                    onPageChange={handleChangePage}
                    onRowsPerPageChange={handleChangeRowsPerPage}
                />
            ) : null}
        </TableContainer>
    )
}
