import {
  ConditionType,
  DatasetRowId,
  DecisionTree,
  DecisionTreeAndNode,
  DecisionTreeAndNodeResult,
  DecisionTreeArrow,
  DecisionTreeAtLeastNNode,
  DecisionTreeAtLeastNNodeResult,
  DecisionTreeAtMostXNode,
  DecisionTreeAtMostXNodeResult,
  DecisionTreeChartNode,
  DecisionTreeChartNodeResult,
  DecisionTreeCommonNode,
  DecisionTreeCommonNodeResult,
  DecisionTreeDateNode,
  DecisionTreeDateNodeResult,
  DecisionTreeNode,
  DecisionTreeNodeResult,
  DecisionTreeOrNode,
  DecisionTreeOrNodeResult,
  DecisionTreePropertyNode,
  DecisionTreePropertyNodeOutputCounts,
  DecisionTreePropertyNodeResult,
  DecisionTreeResult,
  DecisionTreeStructSearchNode,
  DecisionTreeStructureNodeResult,
  DesirabilityFunctionCondition,
  DiscreteCondition,
  DTArrowId,
  DTNodeId,
  DTNodeType,
  DTOperation,
  isAndNode,
  isAtLeastNNode,
  isAtMostXNode,
  isChartNode,
  isDateNode,
  isGroupNode,
  isOrNode,
  isPropertyNode,
  isStructSearchNode,
  PortGroupingType,
  RangeCondition,
  SCORING_FUNCTIONS_RULES,
  SimpleCondition,
} from '@discngine/moosa-models';
import dayjs, { Dayjs } from 'dayjs';
import customParseFormat from 'dayjs/plugin/customParseFormat';

import { parseUtcDatetime } from '../utils';

import { sortDecisionTreeTopologically } from './sortDecisionTreeTopologically';
import { DecisionTreeCellValuesGetter } from './types';

// Register the customParseFormat plugin on the imported `dayjs` module at load time.
// NOTE(review): nothing visible in this module parses dates with a custom format string;
// presumably this is kept for Dayjs values built elsewhere (e.g. by the cell-values getter) — confirm before removing.
dayjs.extend(customParseFormat);

/**
 * Indexes the tree's nodes by id. If two nodes share an id, the later one wins.
 */
const getNodesMap = (dt: DecisionTree) => {
  const byId: Record<DTNodeId, DecisionTreeNode> = {};

  for (const node of dt.nodes) {
    byId[node.id] = node;
  }

  return byId;
};

/**
 * Indexes the tree's arrows by id. If two arrows share an id, the later one wins.
 */
const getArrowsMap = (dt: DecisionTree) => {
  const byId: Record<DTArrowId, DecisionTreeArrow> = {};

  for (const arrow of dt.arrows) {
    byId[arrow.id] = arrow;
  }

  return byId;
};

/**
 * Tests a numeric value against a single-threshold condition.
 *
 * @throws Error('unexpected operator') when the operation is not one of
 *   Less, LessEqual, Greater or GreaterEqual.
 */
const compareSimpleValues = (value: number, condition: SimpleCondition): boolean => {
  const { operation, threshold } = condition;

  if (operation === DTOperation.Less) {
    return value < threshold;
  }

  if (operation === DTOperation.LessEqual) {
    return value <= threshold;
  }

  if (operation === DTOperation.Greater) {
    return value > threshold;
  }

  if (operation === DTOperation.GreaterEqual) {
    return value >= threshold;
  }

  throw new Error('unexpected operator');
};

/**
 * Tests a numeric value against a two-sided range.
 *
 * Both bounds are read as "bound <op> value" and "value <op> bound":
 * `min.operation` Less means `value > min.threshold`, LessEqual means
 * `value >= min.threshold`. Likewise, `max.operation` Less means
 * `value < max.threshold` and LessEqual means `value <= max.threshold`.
 *
 * @throws Error('unexpected operator') when either bound uses an operation
 *   other than Less or LessEqual.
 */
const compareRangeValues = (value: number, condition: RangeCondition): boolean => {
  const { min, max } = condition;

  const minIsStrict = min.operation === DTOperation.Less;
  const maxIsStrict = max.operation === DTOperation.Less;
  const minIsSupported = minIsStrict || min.operation === DTOperation.LessEqual;
  const maxIsSupported = maxIsStrict || max.operation === DTOperation.LessEqual;

  if (!minIsSupported || !maxIsSupported) {
    throw new Error('unexpected operator');
  }

  const aboveMin = minIsStrict ? value > min.threshold : value >= min.threshold;
  const belowMax = maxIsStrict ? value < max.threshold : value <= max.threshold;

  return aboveMin && belowMax;
};

/**
 * Scores `value` with the condition's desirability function, then compares
 * the score with `condition.threshold` using `condition.operation`.
 *
 * @throws Error('unexpected operator') for operations other than Less,
 *   LessEqual, Greater or GreaterEqual.
 */
const compareDesirabilityValues = (
  value: number | string,
  condition: DesirabilityFunctionCondition
): boolean => {
  const { operation, threshold, desirability } = condition;
  const scoringType = desirability.type;

  // The parameters are stored per function type. The cast bridges the
  // different parameter shapes of the individual scoring rules.
  const score = SCORING_FUNCTIONS_RULES[scoringType].getValue(
    value,
    desirability.functionParams[scoringType] as any
  );

  if (operation === DTOperation.Less) {
    return score < threshold;
  }

  if (operation === DTOperation.LessEqual) {
    return score <= threshold;
  }

  if (operation === DTOperation.Greater) {
    return score > threshold;
  }

  if (operation === DTOperation.GreaterEqual) {
    return score >= threshold;
  }

  throw new Error('unexpected operator');
};

/**
 * Returns whether the discrete option that matches `value` is selected.
 *
 * Options are matched by comparing their `value` with `String(value)`, so
 * numeric cells match their string form. Only the first matching option is
 * considered. A value with no matching option yields `false`.
 */
const compareDiscreteValues = (
  value: string | number,
  condition: DiscreteCondition
): boolean => {
  const key = String(value);
  const matchingOption = condition.values.find((option) => option.value === key);

  return matchingOption !== undefined && Boolean(matchingOption.isSelected);
};

/**
 * Evaluates a property-node condition against one non-missing cell value.
 *
 * Discrete and desirability-function conditions accept strings and numbers.
 * Simple and range conditions only apply to numbers, so a string value never
 * satisfies them. A missing condition, or an unrecognised condition type,
 * yields `false`.
 */
const compareValues = (
  value: number | string,
  condition: DecisionTreePropertyNode['condition']
): boolean => {
  if (!condition) {
    return false;
  }

  switch (condition.type) {
    case ConditionType.Discrete:
      return compareDiscreteValues(value, condition);
    case ConditionType.DesirabilityFunction:
      return compareDesirabilityValues(value, condition);
    case ConditionType.Simple:
      return typeof value === 'number' && compareSimpleValues(value, condition);
    case ConditionType.Range:
      return typeof value === 'number' && compareRangeValues(value, condition);
    default:
      return false;
  }
};

/**
 * Compares a date with a threshold at day granularity.
 *
 * The threshold is turned into a day via `parseUtcDatetime(threshold)`:
 * - `Greater`: strictly after the end of the threshold day
 * - `GreaterEqual`: at or after the start of the threshold day
 * - `Less`: strictly before the start of the threshold day
 * - `LessEqual`: at or before the end of the threshold day
 *
 * @param value - the date being tested
 * @param threshold - numeric timestamp passed to `parseUtcDatetime`; expected
 *   to be the start of a UTC day
 * @param operator - comparison operator
 * @throws Error('unexpected operator') for any other operator, matching
 *   `compareSimpleValues`
 */
const compareSimpleDateValues = (
  value: Dayjs,
  threshold: number, // must be start of UTC day
  operator: DTOperation
): boolean => {
  switch (operator) {
    case DTOperation.Greater:
      return value.valueOf() > parseUtcDatetime(threshold).endOf('d').valueOf();
    case DTOperation.GreaterEqual:
      return value.valueOf() >= parseUtcDatetime(threshold).startOf('d').valueOf();
    case DTOperation.Less:
      return value.valueOf() < parseUtcDatetime(threshold).startOf('d').valueOf();
    case DTOperation.LessEqual:
      return value.valueOf() <= parseUtcDatetime(threshold).endOf('d').valueOf();
    default:
      // Fail loudly instead of falling off the switch and returning `undefined`.
      // updateNodeResults would treat `undefined` as neither yes, no nor missing
      // and silently drop the row.
      throw new Error('unexpected operator');
  }
};

/**
 * Inclusive date-range check. Returns true when `value` falls on or after the
 * day of `condition.min.threshold` and on or before the day of
 * `condition.max.threshold`.
 *
 * The `operation` fields of `min`/`max` are ignored, because date ranges are
 * always inclusive at both ends.
 */
const compareRangeDateValues = (value: Dayjs, condition: RangeCondition): boolean => {
  const { min, max } = condition;

  if (!compareSimpleDateValues(value, min.threshold, DTOperation.GreaterEqual)) {
    return false;
  }

  return compareSimpleDateValues(value, max.threshold, DTOperation.LessEqual);
};

/**
 * Evaluates a date-node condition against one date value.
 *
 * NOTE(review): thresholds are parsed with `parseUtcDatetime`, while the
 * original note here said both value and threshold are treated in the user's
 * local time zone — confirm which convention `value` is expected to follow.
 *
 * @param value - the cell date; an invalid date is treated as a missing value
 * @param condition - the date node's condition (simple or range)
 * @returns `null` for an invalid date. Otherwise the comparison result.
 *   Returns `false` when the condition is absent or of an unsupported type.
 */
export const compareDateValues = (
  value: Dayjs,
  condition: DecisionTreeDateNode['condition']
): boolean | null => {
  const isMissing = !value.isValid();

  if (isMissing) {
    return null;
  }

  if (!condition) {
    return false;
  }

  switch (condition.type) {
    case ConditionType.Simple:
      return compareSimpleDateValues(value, condition.threshold, condition.operation);
    case ConditionType.Range:
      return compareRangeDateValues(value, condition);
    default:
      return false;
  }
};

/**
 * Routes one row id into a node's output sets based on its check outcome.
 *
 * - `true`  → `yes`
 * - `false` → `no`
 * - `null`  (missing value) → depends on `portGroupingType`:
 *   `Regular` → `missingValues`, `TrueMissing` → `yes`,
 *   `FalseMissing` → `no`, `HideMissing` → not recorded anywhere.
 *
 * Any other `check` value (e.g. an `undefined` that slips through at runtime)
 * leaves all output sets untouched.
 */
const updateNodeResults = (
  outputDatasetRowIds: DecisionTreePropertyNodeResult['outputDatasetRowIds'],
  portGroupingType: PortGroupingType,
  rowId: string,
  check: boolean | null
): void => {
  if (check === true) {
    outputDatasetRowIds.yes.add(rowId);
  } else if (check === false) {
    outputDatasetRowIds.no.add(rowId);
  } else if (check === null) {
    // Missing value: where it goes depends on how the node groups its ports.
    if (portGroupingType === PortGroupingType.Regular) {
      outputDatasetRowIds.missingValues.add(rowId);
    } else if (portGroupingType === PortGroupingType.TrueMissing) {
      outputDatasetRowIds.yes.add(rowId);
    } else if (portGroupingType === PortGroupingType.FalseMissing) {
      outputDatasetRowIds.no.add(rowId);
    }
    // PortGroupingType.HideMissing: the row is deliberately not recorded.
  }
};

/**
 * Evaluates a property node over the union of its incoming row sets.
 *
 * For each row, the cell value in `node.propertyId` is checked against
 * `node.condition` and the row id is routed into `yes`, `no` or
 * `missingValues` via `updateNodeResults`. Null/undefined cells and a
 * missing column both count as missing values.
 *
 * If the column holds something other than strings or numbers, an error is
 * logged and the result is returned with empty outputs.
 */
const getPropertyNodeOutputRowIds = (
  getCellValues: DecisionTreeCellValuesGetter,
  node: DecisionTreePropertyNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreePropertyNodeResult => {
  const inputDatasetRowIds = new Set<DatasetRowId>();

  for (const rowSet of inputRows) {
    rowSet.forEach((rowId) => inputDatasetRowIds.add(rowId));
  }

  const nodeResult: DecisionTreePropertyNodeResult = {
    id: node.id,
    type: DTNodeType.Property,
    inputDatasetRowIds,
    outputDatasetRowIds: {
      yes: new Set(),
      no: new Set(),
      missingValues: new Set(),
    },
  };

  // Cell values are read positionally, aligned with this array.
  const rowIds = Array.from(inputDatasetRowIds);
  const values = getCellValues(rowIds, node.propertyId);

  if (
    !(values.type === 'string' || values.type === 'numeric' || values.type === 'columnMissing')
  ) {
    console.error(`Values from column ${node.propertyId} are not strings or numbers`);

    return nodeResult;
  }

  const groupingType = node.portGroupingType ?? PortGroupingType.Regular;

  for (const [index, rowId] of rowIds.entries()) {
    // A missing column leaves every row as a missing value (null).
    let check: boolean | null = null;

    if (values.type !== 'columnMissing') {
      const cellValue = values.values[index];

      check = cellValue == null ? null : compareValues(cellValue, node.condition);
    }

    updateNodeResults(nodeResult.outputDatasetRowIds, groupingType, rowId, check);
  }

  return nodeResult;
};

/**
 * Evaluates a date node over the union of its incoming row sets.
 *
 * For each row, the date in `node.propertyId` is checked with
 * `compareDateValues` and the row id is routed into `yes`, `no` or
 * `missingValues` via `updateNodeResults`. Null/undefined cells, invalid
 * dates and a missing column all count as missing values.
 *
 * If the column is not of date type, an error is logged and the result is
 * returned with empty outputs.
 */
const getDateNodeOutputRowIds = (
  getCellValues: DecisionTreeCellValuesGetter,
  node: DecisionTreeDateNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeDateNodeResult => {
  const inputDatasetRowIds = new Set<DatasetRowId>();

  for (const rowSet of inputRows) {
    rowSet.forEach((rowId) => inputDatasetRowIds.add(rowId));
  }

  const nodeResult: DecisionTreeDateNodeResult = {
    id: node.id,
    type: DTNodeType.Date,
    inputDatasetRowIds,
    outputDatasetRowIds: {
      yes: new Set(),
      no: new Set(),
      missingValues: new Set(),
    },
  };

  // Cell values are read positionally, aligned with this array.
  const rowIds = Array.from(inputDatasetRowIds);
  const values = getCellValues(rowIds, node.propertyId);

  if (!(values.type === 'date' || values.type === 'columnMissing')) {
    console.error(`Values from column ${node.propertyId} are not date type`);

    return nodeResult;
  }

  const groupingType = node.portGroupingType ?? PortGroupingType.Regular;

  for (const [index, rowId] of rowIds.entries()) {
    // A missing column leaves every row as a missing value (null).
    let check: boolean | null = null;

    if (values.type !== 'columnMissing') {
      const cellValue = values.values[index];

      check = cellValue == null ? null : compareDateValues(cellValue, node.condition);
    }

    updateNodeResults(nodeResult.outputDatasetRowIds, groupingType, rowId, check);
  }

  return nodeResult;
};

/**
 * Evaluates a structure-search node over the union of its incoming row sets.
 *
 * Each row is tested with `checkSubstructure` against `node.structure`. A
 * `null` answer is a missing value. When the node has no query structure,
 * every row counts as a match. Row ids are routed into `yes`, `no` or
 * `missingValues` via `updateNodeResults`.
 */
const getStructureNodeOutputRowIds = (
  node: DecisionTreeStructSearchNode,
  inputRows: Set<DatasetRowId>[],
  /**
   * The function checks that the core structure is a substructure of the structure in corresponding column cell
   *
   * @param rowId
   * @param columnId - It is expected that columnId has chemical structures in the cells
   * @param core - a structure in mol2 format to compare with
   *
   * @returns - should return null for missing values
   */
  checkSubstructure: (rowId: string, columnId: string, core: string) => boolean | null
): DecisionTreeStructureNodeResult => {
  // Build the union once. It is both the node's recorded input and the set
  // of rows to evaluate, so there is no need to compute it a second time.
  const rowIdsSet = inputRows.reduce<Set<DatasetRowId>>((res, set) => {
    set.forEach((rowId) => res.add(rowId));

    return res;
  }, new Set());

  const nodeResult: DecisionTreeStructureNodeResult = {
    id: node.id,
    type: DTNodeType.StructureSearch,
    inputDatasetRowIds: rowIdsSet,
    outputDatasetRowIds: {
      yes: new Set(),
      no: new Set(),
      missingValues: new Set(),
    },
  };

  const groupingType = node.portGroupingType ?? PortGroupingType.Regular;

  for (const rowId of rowIdsSet) {
    // Without a query structure there is nothing to filter on: every row matches.
    const isSubstructure = node.structure
      ? checkSubstructure(rowId, node.propertyId, node.structure)
      : true;

    updateNodeResults(nodeResult.outputDatasetRowIds, groupingType, rowId, isSubstructure);
  }

  return nodeResult;
};

/**
 * Evaluates an OR node: a row passes if it arrives on any incoming arrow.
 *
 * The same Set instance (the union of all incoming sets) is used as both
 * `inputDatasetRowIds` and the `combine` output.
 */
const getOrNodeOutputRowIds = (
  node: DecisionTreeOrNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeOrNodeResult => {
  const union = new Set<DatasetRowId>();

  for (const rowSet of inputRows) {
    for (const rowId of rowSet) {
      union.add(rowId);
    }
  }

  return {
    id: node.id,
    type: DTNodeType.Or,
    inputDatasetRowIds: union,
    outputDatasetRowIds: {
      combine: union,
    },
  };
};

/**
 * Evaluates a chart node. With several incoming arrows it behaves like an OR
 * node: all incoming rows are combined.
 *
 * The same Set instance (the union of all incoming sets) is used as both
 * `inputDatasetRowIds` and the `combine` output.
 */
const getChartNodeOutputRowIds = (
  node: DecisionTreeChartNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeChartNodeResult => {
  const union = new Set<DatasetRowId>();

  for (const rowSet of inputRows) {
    for (const rowId of rowSet) {
      union.add(rowId);
    }
  }

  return {
    id: node.id,
    type: DTNodeType.Chart,
    inputDatasetRowIds: union,
    outputDatasetRowIds: {
      combine: union,
    },
  };
};

/**
 * Evaluates an AND node: a row passes only if it arrives on every incoming
 * arrow (set intersection). `inputDatasetRowIds` is the union of all incoming
 * sets.
 *
 * Neither `node` nor any of the `inputRows` sets is mutated. The `combine`
 * output is always a new Set, so it never aliases a parent node's output set,
 * even when there is only one incoming arrow. With no incoming sets the
 * output is empty.
 */
const getAndNodeOutputRowIds = (
  node: DecisionTreeAndNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeAndNodeResult => {
  const [first, ...others] = inputRows;
  let intersection = new Set<DatasetRowId>(first ?? []);

  for (const rowSet of others) {
    const narrowed = new Set<DatasetRowId>();

    for (const rowId of rowSet) {
      if (intersection.has(rowId)) {
        narrowed.add(rowId);
      }
    }
    intersection = narrowed;
  }

  const union = new Set<DatasetRowId>();

  for (const rowSet of inputRows) {
    rowSet.forEach((rowId) => union.add(rowId));
  }

  return {
    id: node.id,
    type: DTNodeType.And,
    inputDatasetRowIds: union,
    outputDatasetRowIds: {
      combine: intersection,
    },
  };
};

/**
 * Evaluates an "at least N" node: a row passes when it arrives on at least
 * `node.nodeN` of the incoming arrows. `inputDatasetRowIds` is the union of
 * all incoming sets.
 *
 * Neither `node` nor the incoming sets are mutated.
 */
const getAtLeastNNodeOutputRowIds = (
  node: DecisionTreeAtLeastNNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeAtLeastNNodeResult => {
  const inputDatasetRowIds = new Set<DatasetRowId>();
  // A Map rather than a plain object: ids such as "__proto__", "toString" or
  // "constructor" would otherwise collide with Object.prototype and be lost.
  const occurrences = new Map<DatasetRowId, number>();

  for (const rowSet of inputRows) {
    for (const rowId of rowSet) {
      occurrences.set(rowId, (occurrences.get(rowId) ?? 0) + 1);
      inputDatasetRowIds.add(rowId);
    }
  }

  const outputDatasetRowIds = new Set<DatasetRowId>();

  for (const [rowId, count] of occurrences) {
    if (count >= node.nodeN) {
      outputDatasetRowIds.add(rowId);
    }
  }

  return {
    id: node.id,
    type: DTNodeType.ALN,
    inputDatasetRowIds,
    outputDatasetRowIds: {
      combine: outputDatasetRowIds,
    },
  };
};

/**
 * Evaluates an "at most X" node: a row passes when it arrives on at most
 * `node.nodeN` of the incoming arrows. Only rows present in at least one
 * incoming set are considered. `inputDatasetRowIds` is the union of all
 * incoming sets.
 *
 * Neither `node` nor the incoming sets are mutated.
 */
const getAtMostXNodeOutputRowIds = (
  node: DecisionTreeAtMostXNode,
  inputRows: Set<DatasetRowId>[]
): DecisionTreeAtMostXNodeResult => {
  const inputDatasetRowIds = new Set<DatasetRowId>();
  // A Map rather than a plain object: ids such as "__proto__", "toString" or
  // "constructor" would otherwise collide with Object.prototype and be lost.
  const occurrences = new Map<DatasetRowId, number>();

  for (const rowSet of inputRows) {
    for (const rowId of rowSet) {
      occurrences.set(rowId, (occurrences.get(rowId) ?? 0) + 1);
      inputDatasetRowIds.add(rowId);
    }
  }

  const outputDatasetRowIds = new Set<DatasetRowId>();

  for (const [rowId, count] of occurrences) {
    if (count <= node.nodeN) {
      outputDatasetRowIds.add(rowId);
    }
  }

  return {
    id: node.id,
    type: DTNodeType.AMX,
    inputDatasetRowIds,
    outputDatasetRowIds: {
      combine: outputDatasetRowIds,
    },
  };
};

/**
 * Returns the row ids that flow out of a parent node along a given arrow.
 *
 * The arrow id is looked up among the parent's `outputArrows` slots. The
 * matching slot name then selects the row set in the parent's
 * `outputDatasetRowIds`.
 *
 * @param node - the parent node
 * @param nodeResult - the already computed result of the parent (`node.id === nodeResult.id`)
 * @param arrowId - the arrow from the parent to the current node
 * @returns the parent's output row set for that slot (the same Set, not a copy)
 * @throws Error when no slot of `node` lists `arrowId`, or when `nodeResult`
 *   has no row set for that slot
 */
const getOutputRowIds = (
  node: DecisionTreeCommonNode,
  nodeResult: DecisionTreeCommonNodeResult,
  arrowId: DTArrowId
): Set<DatasetRowId> => {
  let arrowSlot: string | undefined = undefined;

  for (const slot of Object.keys(node.outputArrows)) {
    const slotArrowIds = node.outputArrows[slot];

    if (slotArrowIds?.some((candidate) => candidate === arrowId)) {
      arrowSlot = slot;
      break;
    }
  }

  if (!arrowSlot) {
    throw new Error(`cannot find slot for ${node.id}, ${arrowId}`);
  }

  const rowIds = nodeResult.outputDatasetRowIds[arrowSlot];

  if (!rowIds) {
    throw new Error(`cannot find rowIds for ${node.id}, ${arrowId}, ${arrowSlot}`);
  }

  return rowIds;
};

/**
 * Re-evaluates a single property node and returns the size of each output.
 *
 * A node without incoming arrows is evaluated over the whole dataset
 * (`datasetRowIds`). Otherwise it reuses the input rows recorded for it in
 * `nodesResult`, which must therefore contain an entry for `node.id`.
 */
export const calculateOutputCounts = (
  getCellValues: DecisionTreeCellValuesGetter,
  datasetRowIds: DatasetRowId[],
  node: DecisionTreePropertyNode,
  nodesResult: DecisionTreeResult
): DecisionTreePropertyNodeOutputCounts => {
  const inputRowIds: Set<DatasetRowId> =
    node.inputArrows.length > 0
      ? nodesResult.nodes[node.id].inputDatasetRowIds
      : new Set<DatasetRowId>(datasetRowIds);

  const { outputDatasetRowIds } = getPropertyNodeOutputRowIds(getCellValues, node, [
    inputRowIds,
  ]);
  const { yes, no, missingValues } = outputDatasetRowIds;

  return {
    yes: yes.size,
    no: no.size,
    missingValues: missingValues.size,
  };
};

/**
 * Evaluates the whole decision tree over the given dataset rows.
 *
 * Nodes are visited in the order returned by `sortDecisionTreeTopologically`.
 * Root nodes (no input arrows) receive every dataset row. Other nodes receive,
 * for each input arrow, the row set from the parent's output slot that the
 * arrow leaves from. Each node type is then evaluated by its dedicated helper.
 *
 * @param getCellValues - reads column values for a list of row ids
 * @param datasetRowIds - all row ids of the dataset
 * @param dt - the decision tree to evaluate
 * @param checkSubstructure - substructure test used by structure-search nodes;
 *   defaults to a function that always returns `true` (every row matches)
 * @returns per-node results keyed by node id, the dataset row ids, and the
 *   `isCyclicGraph` flag reported by the topological sort
 * @throws Error('Unexpected node type') for a node no type guard recognises
 */
export const calculateDecisionTreeResults = (
  getCellValues: DecisionTreeCellValuesGetter,
  datasetRowIds: DatasetRowId[],
  dt: DecisionTree,
  checkSubstructure: (
    rowId: string,
    columnId: string,
    core: string
  ) => boolean | null = () => true
): DecisionTreeResult => {
  // NOTE(review): it is not visible here which node ids the sort returns for a
  // cyclic graph — confirm that a node is never listed before its parents.
  const { nodeIds: sortedNodes, isCyclicGraph } = sortDecisionTreeTopologically(dt);

  const nodesResult: Record<DTNodeId, DecisionTreeNodeResult> = {};

  const nodesMap = getNodesMap(dt);
  const arrowsMap = getArrowsMap(dt);

  for (const nodeId of sortedNodes) {
    const node = nodesMap[nodeId]!;

    // Root nodes see the whole dataset; otherwise collect one row set per input arrow.
    const inputRowIds: Set<DatasetRowId>[] = !node.inputArrows.length
      ? [new Set(datasetRowIds)]
      : node.inputArrows.map((inputArrowId) => {
          const arrow = arrowsMap[inputArrowId];
          const parentNodeId = arrow.from;
          const parentNode = nodesMap[parentNodeId];
          // Relies on the parent having been evaluated earlier in this loop. If it
          // was not (or the parent is a group node, which stores no result),
          // this is undefined and getOutputRowIds will throw a TypeError.
          const parentResultNode = nodesResult[parentNodeId];
          const parentNodeOutputRowIds = getOutputRowIds(
            parentNode,
            parentResultNode,
            inputArrowId
          );

          return parentNodeOutputRowIds;
        });

    if (isAndNode(node)) {
      nodesResult[nodeId] = getAndNodeOutputRowIds(node, inputRowIds);
    } else if (isAtLeastNNode(node)) {
      nodesResult[nodeId] = getAtLeastNNodeOutputRowIds(node, inputRowIds);
    } else if (isAtMostXNode(node)) {
      nodesResult[nodeId] = getAtMostXNodeOutputRowIds(node, inputRowIds);
    } else if (isOrNode(node)) {
      nodesResult[nodeId] = getOrNodeOutputRowIds(node, inputRowIds);
    } else if (isChartNode(node)) {
      nodesResult[nodeId] = getChartNodeOutputRowIds(node, inputRowIds);
    } else if (isPropertyNode(node)) {
      nodesResult[nodeId] = getPropertyNodeOutputRowIds(getCellValues, node, inputRowIds);
    } else if (isDateNode(node)) {
      nodesResult[nodeId] = getDateNodeOutputRowIds(getCellValues, node, inputRowIds);
    } else if (isStructSearchNode(node)) {
      nodesResult[nodeId] = getStructureNodeOutputRowIds(
        node,
        inputRowIds,
        checkSubstructure
      );
    } else if (isGroupNode(node)) {
      // Group nodes are not evaluated and get no entry in nodesResult.
    } else {
      // Compile-time exhaustiveness check: `node` must be `never` here. Throws at
      // runtime if an unknown node type slips through.
      ((x: never) => {
        throw new Error('Unexpected node type');
      })(node);
    }
  }

  const result: DecisionTreeResult = {
    datasetRowIds: new Set(datasetRowIds),
    nodes: nodesResult,
    isCyclicGraph,
  };

  return result;
};
