import { findParentNode } from "@tiptap/core";
// import { Selection, Transaction } from "@tiptap/pm/state";
import { CellSelection, TableMap } from "@tiptap/pm/tables";
// import { Node, ResolvedPos } from "@tiptap/pm/model";

/**
 * Builds a predicate that tells whether every cell inside `rect` is part of
 * the rectangle spanned by a cell selection.
 *
 * @param {{left: number, right: number, top: number, bottom: number}} rect
 *   Rectangle handed to the table map's `cellsInRect`.
 * @returns {(selection: object) => boolean} Predicate over a selection that
 *   exposes `$anchorCell` and `$headCell`. It returns `true` when all cells
 *   of `rect` are among the selected cells (also when `rect` has no cells).
 */
export const isRectSelected = (rect) => (selection) => {
	const { $anchorCell, $headCell } = selection;
	const tableMap = TableMap.get($anchorCell.node(-1));
	const tableStart = $anchorCell.start(-1);

	// The map works with offsets relative to the table start, so both
	// selection ends are rebased before asking for the covered rectangle.
	const selectedRect = tableMap.rectBetween($anchorCell.pos - tableStart, $headCell.pos - tableStart);
	const selectedCells = tableMap.cellsInRect(selectedRect);

	return tableMap.cellsInRect(rect).every((cellPos) => selectedCells.includes(cellPos));
};

/**
 * Looks up the closest ancestor of the selection whose node type declares
 * `tableRole: "table"` in its spec.
 *
 * @param {object} selection Selection passed through to `findParentNode`.
 * @returns {*} Whatever `findParentNode` yields for the table predicate.
 */
export const findTable = (selection) => {
	const isTableNode = (node) => node.type.spec.tableRole === "table";
	return findParentNode(isTableNode)(selection);
};

/**
 * Tells whether the given selection is a `CellSelection` instance.
 *
 * @param {*} selection Value to test.
 * @returns {boolean} Result of the `instanceof CellSelection` check.
 */
export const isCellSelection = (selection) => {
	return selection instanceof CellSelection;
};

/**
 * Builds a predicate that tells whether the column at `columnIndex` is fully
 * covered by a cell selection.
 *
 * @param {number} columnIndex Column index used as the rectangle's `left`.
 * @returns {(selection: object) => boolean} Predicate. It returns `false`
 *   for anything that is not a cell selection, otherwise the result of
 *   `isRectSelected` for the full-height, one-column-wide rectangle.
 */
export const isColumnSelected = (columnIndex) => (selection) => {
	if (!isCellSelection(selection)) {
		return false;
	}

	const { height } = TableMap.get(selection.$anchorCell.node(-1));
	const columnRect = {
		top: 0,
		bottom: height,
		left: columnIndex,
		right: columnIndex + 1
	};
	return isRectSelected(columnRect)(selection);
};

/**
 * Builds a predicate that tells whether the row at `rowIndex` is fully
 * covered by a cell selection.
 *
 * @param {number} rowIndex Row index used as the rectangle's `top`.
 * @returns {(selection: object) => boolean} Predicate. It returns `false`
 *   for anything that is not a cell selection, otherwise the result of
 *   `isRectSelected` for the full-width, one-row-high rectangle.
 */
export const isRowSelected = (rowIndex) => (selection) => {
	if (!isCellSelection(selection)) {
		return false;
	}

	const { width } = TableMap.get(selection.$anchorCell.node(-1));
	const rowRect = {
		top: rowIndex,
		bottom: rowIndex + 1,
		left: 0,
		right: width
	};
	return isRectSelected(rowRect)(selection);
};

/**
 * Tells whether a cell selection covers the whole table.
 *
 * @param {object} selection Selection to inspect.
 * @returns {boolean} `false` for anything that is not a cell selection,
 *   otherwise the result of `isRectSelected` for the rectangle spanning the
 *   full width and height of the table map.
 */
export const isTableSelected = (selection) => {
	if (!isCellSelection(selection)) {
		return false;
	}

	const { width, height } = TableMap.get(selection.$anchorCell.node(-1));
	const tableRect = { top: 0, bottom: height, left: 0, right: width };
	return isRectSelected(tableRect)(selection);
};

/**
 * Builds a getter that collects the cells of one or more columns of the
 * table containing the selection.
 *
 * @param {number|number[]} columnIndex A single column index or a list of
 *   them. Indexes failing `index >= 0 && index < map.width` are skipped.
 * @returns {(selection: object) => (Array<{pos: number, start: number, node: *}>|null)}
 *   Getter that returns `null` when `findTable` finds no table. Otherwise it
 *   returns one entry per cell, in the order the columns were requested,
 *   where `pos` is the map offset plus `table.start`, `start` is `pos + 1`
 *   and `node` is `table.node.nodeAt(offset)`.
 */
export const getCellsInColumn = (columnIndex) => (selection) => {
	const table = findTable(selection);
	if (!table) {
		return null;
	}

	const tableMap = TableMap.get(table.node);
	const requestedColumns = Array.isArray(columnIndex) ? columnIndex : [columnIndex];
	const collected = [];

	for (const column of requestedColumns) {
		if (column >= 0 && column < tableMap.width) {
			const offsets = tableMap.cellsInRect({
				top: 0,
				bottom: tableMap.height,
				left: column,
				right: column + 1
			});
			for (const offset of offsets) {
				const node = table.node.nodeAt(offset);
				const pos = offset + table.start;
				collected.push({ pos, start: pos + 1, node });
			}
		}
	}

	return collected;
};

/**
 * Builds a getter that collects the cells of one or more rows of the table
 * containing the selection.
 *
 * @param {number|number[]} rowIndex A single row index or a list of them.
 *   Indexes failing `index >= 0 && index < map.height` are skipped.
 * @returns {(selection: object) => (Array<{pos: number, start: number, node: *}>|null)}
 *   Getter that returns `null` when `findTable` finds no table. Otherwise it
 *   returns one entry per cell, in the order the rows were requested, where
 *   `pos` is the map offset plus `table.start`, `start` is `pos + 1` and
 *   `node` is `table.node.nodeAt(offset)`.
 */
export const getCellsInRow = (rowIndex) => (selection) => {
	const table = findTable(selection);
	if (!table) {
		return null;
	}

	const tableMap = TableMap.get(table.node);
	const requestedRows = Array.isArray(rowIndex) ? rowIndex : [rowIndex];
	const collected = [];

	for (const row of requestedRows) {
		if (row >= 0 && row < tableMap.height) {
			const offsets = tableMap.cellsInRect({
				top: row,
				bottom: row + 1,
				left: 0,
				right: tableMap.width
			});
			for (const offset of offsets) {
				const node = table.node.nodeAt(offset);
				const pos = offset + table.start;
				collected.push({ pos, start: pos + 1, node });
			}
		}
	}

	return collected;
};

/**
 * Collects every cell of the table containing the selection.
 *
 * @param {object} selection Selection handed to `findTable`.
 * @returns {Array<{pos: number, start: number, node: *}>|null} `null` when
 *   `findTable` finds no table. Otherwise one entry per cell returned by
 *   the map for the full-table rectangle, where `pos` is the map offset
 *   plus `table.start`, `start` is `pos + 1` and `node` is
 *   `table.node.nodeAt(offset)`.
 */
export const getCellsInTable = (selection) => {
	const table = findTable(selection);
	if (!table) {
		return null;
	}

	const tableMap = TableMap.get(table.node);
	const wholeTable = { top: 0, bottom: tableMap.height, left: 0, right: tableMap.width };

	return tableMap.cellsInRect(wholeTable).map((offset) => {
		const node = table.node.nodeAt(offset);
		const pos = offset + table.start;
		return { pos, start: pos + 1, node };
	});
};

/**
 * Walks up from `$pos.depth` towards (but not including) depth 0 and returns
 * the first ancestor node accepted by `predicate`.
 *
 * @param {object} $pos Resolved position exposing `depth`, `node(depth)`,
 *   `before(depth)` and `start(depth)`.
 * @param {(node: *) => *} predicate Called with each ancestor, deepest
 *   first. A truthy result stops the walk.
 * @returns {{pos: number, start: number, depth: number, node: *}|null}
 *   Details of the matching ancestor, or `null` when none matches.
 */
export const findParentNodeClosestToPos = ($pos, predicate) => {
	let depth = $pos.depth;

	while (depth > 0) {
		const candidate = $pos.node(depth);
		if (predicate(candidate)) {
			// depth is always > 0 here, so the position just before the
			// node can be read without a special case for the root.
			return {
				pos: $pos.before(depth),
				start: $pos.start(depth),
				depth,
				node: candidate
			};
		}
		depth -= 1;
	}

	return null;
};

/**
 * Finds the closest ancestor of `$pos` whose node type spec has a
 * `tableRole` containing "cell" (case-insensitive).
 *
 * @param {object} $pos Resolved position passed to
 *   `findParentNodeClosestToPos`.
 * @returns {*} Whatever `findParentNodeClosestToPos` returns for that
 *   predicate.
 */
export const findCellClosestToPos = ($pos) => {
	const isCellNode = (node) => {
		const role = node.type.spec.tableRole;
		return role && /cell/i.test(role);
	};
	return findParentNodeClosestToPos($pos, isCellNode);
};

/**
 * Builds a transaction transformer that selects one whole row or one whole
 * column of the table containing the transaction's current selection.
 *
 * The transaction is returned unchanged when there is no enclosing table,
 * when `index` is outside the table, or when the map reports no cell for
 * the requested row/column.
 *
 * @param {string} type `"row"` selects a row; any other value selects a
 *   column.
 * @returns {(index: number) => (tr: object) => object} Curried function
 *   taking the row/column index and then the transaction. The transaction
 *   is mutated through `tr.setSelection` and returned.
 */
const select = (type) => (index) => (tr) => {
	const table = findTable(tr.selection);
	if (!table) return tr;

	const map = TableMap.get(table.node);
	const isRowSelection = type === "row";

	// Same bounds rule as getCellsInRow/getCellsInColumn: an index outside
	// the table does not describe a real row/column, so leave `tr` alone
	// instead of resolving a bogus position.
	const limit = isRowSelection ? map.height : map.width;
	if (!(index >= 0 && index < limit)) return tr;

	const rect = isRowSelection
		? { left: 0, top: index, right: map.width, bottom: index + 1 }
		: { left: index, top: 0, right: index + 1, bottom: map.height };

	const cells = map.cellsInRect(rect);
	// With no cell to anchor on, `cells[0]` would be undefined and the
	// position arithmetic below would yield NaN.
	// NOTE(review): presumably possible when every cell of the line is the
	// continuation of a merged cell — verify against TableMap.cellsInRect.
	if (cells.length === 0) return tr;

	const $head = tr.doc.resolve(table.start + cells[0]);
	const $anchor = tr.doc.resolve(table.start + cells[cells.length - 1]);

	tr.setSelection(new CellSelection($anchor, $head));
	return tr;
};

/** Selects the column at the given index: `selectColumn(index)(tr)`. */
export const selectColumn = select("column");
/** Selects the row at the given index: `selectRow(index)(tr)`. */
export const selectRow = select("row");
/**
 * Selects every cell of the table containing the transaction's current
 * selection.
 *
 * @param {object} tr Transaction; it is mutated through `tr.setSelection`.
 * @returns {object} The same transaction, unchanged when `findTable` finds
 *   no table. Otherwise it carries a `CellSelection` whose head is the
 *   first cell reported by the map and whose anchor is the last one.
 */
export const selectTable = (tr) => {
	const table = findTable(tr.selection);
	if (!table) {
		return tr;
	}

	const tableMap = TableMap.get(table.node);
	const allCells = tableMap.cellsInRect({
		top: 0,
		bottom: tableMap.height,
		left: 0,
		right: tableMap.width
	});
	const firstOffset = allCells[0];
	const lastOffset = allCells[allCells.length - 1];

	// Map offsets are relative to the table, so shift them by its start
	// before resolving them against the document.
	const $head = tr.doc.resolve(table.start + firstOffset);
	const $anchor = tr.doc.resolve(table.start + lastOffset);

	tr.setSelection(new CellSelection($anchor, $head));
	return tr;
};
