import { BarSeriesType, LineSeriesType } from '@mui/x-charts';
import { QueryKey, useSuspenseQueries, useSuspenseQuery } from '@tanstack/react-query';
import { TFunction } from 'i18next';
import isEqual from 'lodash.isequal';
import { useReducer } from 'react';
import { match } from 'ts-pattern';
import { useDebounce } from 'use-debounce';
import { useAliases } from '../../../../api/alias/hooks';
import { MetricDetails } from '../../../../api/metrics/types';
import { MetricResult } from '../../../../api/types';
import {
  CohortMetricId,
  EmployeeCohortMetricId,
  MetricGroupId,
  MetricId,
  RegularMetricId,
} from '../../../../api/types-graphql';
import {
  getMetricTypeFromMetrics,
  metricIdNeedsEmptyCohortFilter,
  toMetricResultOrError,
  toSegmentation,
} from '../../../../api/utils';
import { Languages } from '../../../../constants';
import {
  useAliasServiceContext,
  useGlobalLocaleContext,
  useMetricDetailsMapContext,
  useMetricServiceContext,
} from '../../../../context/contexts';
import { GlobalDisplayState } from '../../../../context/types';
import {
  Colors,
  DataFieldWithDataType,
  MetricIdType,
  MetricIdTypeArrays,
  Segmentation,
  TimeSelection,
} from '../../../../types';
import { getFilterConditions, mergeSegmentations, removeDuplicates } from '../../../utils';
import {
  AreaChartConfig,
  ChartTypeConfig,
  LineChartConfig,
  LineChartQueryConfig,
  OverTimeBarChartConfig,
  OverTimeBarChartQueryConfig,
  OverTimeTableConfig,
} from '../../dashboards/types';
import { Filter, Segment } from '../../filter/filterbar/types';
import { ItemsState } from '../../filter/types';
import { NonSegmentedTableOverTimeData } from '../../tableview/types';
import {
  NonSegmentedOverTimeTableDataColumnWiseMonoid,
  NonSegmentedOverTimeTableDataMonoid,
} from '../../tableview/utils';
import { TimeSliderState } from '../../timeslider/types';
import { MuiOverTimeChartData } from '../mui-charts/types';
import {
  DisplayState,
  GlobalOverTimeDisplayAction,
  LocalOverTimeDisplayAction,
  OverTimeDisplayHandle,
  QueryConfig,
  ToolsAction,
  UseMetricQueriesReturnType,
} from './types';
import { timeSliderStateToOverTimeTimeSelection } from './utils';

/**
 * Decides whether the chart described by `state` behaves as a stacked bar chart:
 * its chart type is `'bar'`, it plots exactly one metric, and a (truthy)
 * segmentation list is present. Percentage and stack tools are only offered in
 * that case.
 */
const isStackedBarChart = (state: DisplayState): boolean => {
  // Evaluated up front; only consulted when the chart type is 'bar'.
  const hasSingleSegmentedMetric = state.metrics.length === 1 && Boolean(state.segmentations);

  return match(state.chartTypeConfig)
    .with({ chartType: 'bar' }, () => hasSingleSegmentedMetric)
    .otherwise(() => false);
};

/**
 * Creates the local display-state reducer for an over-time chart.
 *
 * The initial state records the chart's configuration, metrics, segmentations
 * and the current global display state, and starts with labels/legend
 * available but hidden, percentage/stack tools disabled and the table view off.
 * As with any `useReducer`, only the first render's arguments seed the state;
 * later updates go through `reducer` actions.
 *
 * @returns The `[state, dispatch]` pair produced by `useReducer`.
 */
export const useDisplay = (
  chartTypeConfig: ChartTypeConfig,
  segmentations: Segmentation[] | undefined,
  metrics: MetricIdType[],
  displayGlobalState: GlobalDisplayState
): OverTimeDisplayHandle => {
  const initialDisplayState: DisplayState = {
    // Chart context
    chartTypeConfig,
    metrics,
    segmentations,
    globalDisplayState: displayGlobalState,
    // Labels and legend: available, initially hidden
    enableLabels: true,
    showLabels: false,
    enableLegend: true,
    showLegend: false,
    // Percentage and stack: disabled until the reducer enables them
    enablePercentage: false,
    showPercentage: false,
    enableStack: false,
    showStack: false,
    // Chart view (not table view) by default
    showTableView: false,
  };

  return useReducer(reducer, initialDisplayState);
};

/**
 * Reducer for an over-time chart's local display state.
 *
 * Percentage and stack tools are only meaningful for a stacked bar chart (see
 * `isStackedBarChart`) and are mutually exclusive. Every branch that touches the
 * `enable*` / `show*` flags for these tools keeps them false when the chart is
 * not a stacked bar chart, and never shows stack and percentage at the same time.
 *
 * @param state - Current display state.
 * @param action - Tool toggle, table-view toggle, segmentation change
 *   (`toggle-tools`) or a sync from the global display state.
 * @returns The next display state (a new object; `state` is not mutated).
 */
export const reducer = (
  state: DisplayState,
  action: GlobalOverTimeDisplayAction | LocalOverTimeDisplayAction | ToolsAction
): DisplayState => {
  // Whether the chart, as currently configured, supports percentage/stack tools.
  const stackedBarChart = isStackedBarChart(state);
  return match(action)
    .with({ type: 'toggle-tools' }, (a: ToolsAction) => {
      // Re-evaluate against the incoming segmentations rather than the current ones.
      const nextIsStacked = isStackedBarChart({ ...state, segmentations: a.segmentations });
      const enablePercentage = nextIsStacked && state.globalDisplayState.enablePercentage;
      const showPercentage = nextIsStacked && state.globalDisplayState.showPercentage;

      return {
        ...state,
        enablePercentage,
        showPercentage,
        enableStack: nextIsStacked && state.globalDisplayState.enableStack && !showPercentage,
        // Keep the local stack choice, but drop it when the chart is no longer stacked
        // or percentage mode is now shown (the two are mutually exclusive).
        showStack: nextIsStacked && state.showStack && !showPercentage,
        segmentations: a.segmentations,
      };
    })
    .with({ type: 'toggle-labels' }, () => {
      return {
        ...state,
        showLabels: !state.showLabels,
      };
    })
    .with({ type: 'toggle-percentage' }, () => {
      const showPercentage = stackedBarChart && !state.showPercentage;
      return {
        ...state,
        showPercentage,
        // Stack is unavailable while percentages are shown, and never offered on a
        // chart that is not a stacked bar chart.
        enableStack: stackedBarChart && !showPercentage,
      };
    })
    .with({ type: 'toggle-stack' }, () => {
      const showStack = stackedBarChart && !state.showStack;
      return {
        ...state,
        showStack,
        // Percentage is unavailable while stacking is shown, and never offered on a
        // chart that is not a stacked bar chart.
        enablePercentage: stackedBarChart && !showStack,
      };
    })
    .with({ type: 'toggle-legend' }, () => {
      return {
        ...state,
        showLegend: !state.showLegend,
      };
    })
    .with({ type: 'toggle-tableView' }, () => {
      // Entering the table view (showTableView: false -> true) disables every tool.
      // Leaving it re-enables labels/legend and restores percentage/stack enablement
      // from the snapshot taken when the table view was entered.
      const enablePercentage = stackedBarChart && (state.prevState?.enablePercentage ?? false) && state.showTableView;
      const enableStack = stackedBarChart && (state.prevState?.enableStack ?? false) && state.showTableView;
      const enableLabels = state.showTableView;
      const enableLegend = state.showTableView;
      return {
        ...state,
        showTableView: !state.showTableView,
        enablePercentage,
        enableStack,
        enableLabels,
        enableLegend,
        prevState: { ...state, prevState: undefined }, // Drop the nested snapshot to avoid an ever-growing chain
      };
    })
    .with({ type: 'global-state-sync' }, (a) => {
      // Mirror the global flags, restricted to stacked bar charts; percentage wins over stack.
      const showPercentage = stackedBarChart && a.globalDisplayState.showPercentage;
      const showStack = stackedBarChart && a.globalDisplayState.showStack && !showPercentage;
      const enablePercentage = stackedBarChart && a.globalDisplayState.enablePercentage;
      const enableStack = stackedBarChart && a.globalDisplayState.enableStack && !showPercentage;

      return {
        ...state,
        showLabels: a.globalDisplayState.showLabels,
        showPercentage,
        showStack,
        showLegend: a.globalDisplayState.showLegend,
        enablePercentage,
        enableStack,
        globalDisplayState: a.globalDisplayState,
      };
    })
    .exhaustive();
};

/**
 * Validates the metrics of a single over-time query.
 *
 * The hook derives the metric type and the empty-cohort-filter requirement from a
 * query's metrics as a group (and from its first metric). Every query therefore needs
 * at least one metric, and all of its metrics must share one type (regular, cohort or
 * employee cohort). TODO: enforce this through typing instead of at runtime.
 *
 * @throws Error if `metrics` is empty or mixes metric types.
 */
const validateQueryMetrics = (metrics: MetricIdTypeArrays): void => {
  if (!metrics.length) {
    throw new Error('Error: query.metrics array is empty.');
  }
  const uniqueMetricTypes = removeDuplicates(metrics.map((m) => m.type));
  if (uniqueMetricTypes.length > 1) {
    throw new Error(`Error: query.metrics contains metrics of different types - ${uniqueMetricTypes.join(', ')}`);
  }
};

/**
 * Fetches and combines the data for one or more over-time queries (a query may be
 * flagged `isBenchmark`), and exposes the chart's local display state.
 *
 * For every query the hook:
 * - uses the time slider's selection instead of the query's own when a slider state is given;
 * - merges the segmentation-bar items into the query's segmentations;
 * - ANDs the filter-bar conditions with the query's own filter string;
 * - includes `manualDataRefetchCount` in the query key, so bumping it triggers a refetch.
 *
 * Each query's results are converted by `processData`, and the outputs are concatenated
 * into one `MuiOverTimeChartData`. Both the fetch list and the derived-data key are
 * debounced by 500 ms via `useDebounce`.
 *
 * @param queries - Query configurations; each needs at least one metric, all of one type.
 * @param timeSliderState - When non-null, overrides every query's `timeSelection`.
 * @param filtersState - Filter-bar items applied to every query.
 * @param segmentationState - Segmentation-bar items merged into every query's segmentations.
 * @param manualDataRefetchCount - Part of every query key; incrementing it refetches.
 * @param chartTypeConfig - Chart configuration used to seed the display state.
 * @param displayGlobalState - Global display state used to seed the local display state.
 * @param formatLabel - Forwarded to `processData`.
 * @param colors - Forwarded to `processData`.
 * @param t - Translation function forwarded to `processData`.
 * @param processData - Converts one query's results into chart data.
 * @returns The combined chart data, the display `[state, dispatch]` handle, the metrics,
 *   segmentations and filters across all queries, the first query's time selection,
 *   and any query errors.
 * @throws Error if any query has no metrics or mixes metric types.
 */
export const useMetricQueries = <T extends LineSeriesType | BarSeriesType>(
  queries: LineChartQueryConfig[] | OverTimeBarChartQueryConfig[],
  timeSliderState: TimeSliderState | null,
  filtersState: ItemsState<Filter>,
  segmentationState: ItemsState<Segment>,
  manualDataRefetchCount: number,
  chartTypeConfig: LineChartConfig | AreaChartConfig | OverTimeBarChartConfig | OverTimeTableConfig,
  displayGlobalState: GlobalDisplayState,
  formatLabel: (value: (string | null)[], dataField: DataFieldWithDataType) => string,
  colors: Colors | undefined,
  t: TFunction,
  processData: (
    resultOrErrors: Array<MetricResult | Error>,
    xAxisId: number,
    isBenchmark: boolean | undefined,
    segmentations: Segmentation[] | undefined,
    displayState: DisplayState,
    metricDetailsMap: Record<RegularMetricId | CohortMetricId | EmployeeCohortMetricId, MetricDetails>,
    formatLabel: (value: (string | null)[], datafield: DataFieldWithDataType) => string,
    getAliasForMetric: (metricGroupId: MetricGroupId) => string | null,
    locale: Languages,
    colors: Colors | undefined,
    t: TFunction
  ) => MuiOverTimeChartData<T>
): UseMetricQueriesReturnType<T> => {
  const metricDetailsMap = useMetricDetailsMapContext();
  const aliasService = useAliasServiceContext();
  const { data: aliases } = useAliases(aliasService);
  const locale = useGlobalLocaleContext();
  const getAliasForMetric = aliasService.getAliasForMetricGroupId(aliases, locale.selected);

  const metricService = useMetricServiceContext();

  const queryConfigs = queries.reduce<QueryConfig[]>((acc, query) => {
    const { metrics, timeSelection, filters, segmentations, xAxisId, isBenchmark } = query;

    // Validate first: everything below reads `metrics[0]` or derives the metric type from `metrics`.
    validateQueryMetrics(metrics);

    const finalTimeSelection: TimeSelection = timeSliderState
      ? timeSliderStateToOverTimeTimeSelection(timeSliderState)
      : timeSelection;
    const metricType = getMetricTypeFromMetrics(metrics);
    // All metrics share one type (validated above), so the first metric is representative.
    // TODO: Improve typing to avoid the need for this cast.
    const needsEmptyCohortFilter = metricIdNeedsEmptyCohortFilter(metrics[0].value as unknown as MetricId);

    // NOTE(review): the meaning of mergeSegmentations' trailing numeric argument is not visible here.
    const allSegmentations = mergeSegmentations(
      segmentations,
      toSegmentation(segmentationState.items, metricType),
      metrics,
      1
    );

    const queryKey: QueryKey = [
      JSON.stringify([
        'over-time',
        metrics,
        finalTimeSelection,
        filtersState.items,
        filters,
        allSegmentations,
        manualDataRefetchCount,
      ]),
    ];

    const sqlFiltersFromFilterState = getFilterConditions(filtersState.items, metricType, needsEmptyCohortFilter);
    const allFilters = [sqlFiltersFromFilterState, filters].flatMap((f) => (f ? [f] : [])).join(' AND ') || undefined;
    const queryFn = () =>
      toMetricResultOrError(metricService, metrics, finalTimeSelection, allFilters, allSegmentations);

    acc.push({
      queryKey,
      queryFn,
      metrics,
      segmentations: allSegmentations ?? undefined,
      filters: allFilters,
      timeSelection: finalTimeSelection,
      xAxisId,
      isBenchmark,
    });
    return acc;
  }, []);

  const totalSegmentations = queryConfigs.reduce<Segmentation[] | undefined>(
    (acc, qc) => mergeSegmentations(acc, qc.segmentations, qc.metrics, 2),
    undefined
  );
  const totalMetrics = queryConfigs.flatMap<MetricIdType>((qc) => qc.metrics);
  const totalFilters = queryConfigs.flatMap((qc) => (qc.filters ? [qc.filters] : [])).join(' AND ');

  const displayHandle = useDisplay(chartTypeConfig, totalSegmentations, totalMetrics, displayGlobalState);
  const [displayState] = displayHandle;

  // Only refetch once the set of query keys has settled for 500 ms.
  const [queriesToFetch] = useDebounce(
    queryConfigs.map((qc) => ({ queryKey: qc.queryKey, queryFn: qc.queryFn })),
    500,
    {
      equalityFn: (prev, next) =>
        isEqual(
          prev.map((p) => p.queryKey),
          next.map((n) => n.queryKey)
        ),
    }
  );
  const [debouncedQueryKeys] = useDebounce(JSON.stringify(queryConfigs.map((qc) => qc.queryKey)), 500);

  const { data: results, error } = useSuspenseQueries({
    queries: queriesToFetch,
    combine: (results) => {
      return {
        data: results,
        error: results.flatMap((r) => (r.error ? [r.error] : [])),
      };
    },
  });

  // The processed chart data is cached under the debounced query keys plus selected inputs of
  // `processData`.
  // NOTE(review): `processData` receives the whole `displayState`, but only the flags below are
  // part of the key. If it reads other fields, the cached data can go stale.
  // NOTE(review): `results` are debounced while `queryConfigs` are not; indexing one by the other
  // assumes both have settled to the same list of queries when this queryFn runs.
  const { data } = useSuspenseQuery({
    queryKey: [
      debouncedQueryKeys,
      displayState.showStack,
      displayState.showPercentage,
      displayState.showLabels,
      // Aliases and formatted output passed to processData depend on the selected locale.
      locale.selected,
    ],
    queryFn: () => {
      const tableViewDataRowWiseMonoid = new NonSegmentedOverTimeTableDataMonoid();
      const tableViewDataColumnWiseMonoid = new NonSegmentedOverTimeTableDataColumnWiseMonoid();
      return results.reduce<MuiOverTimeChartData<T>>((acc, r, i) => {
        const query = queryConfigs[i];
        // Benchmark queries are combined column-wise, all others row-wise.
        const tableViewDataMonoid = query.isBenchmark ? tableViewDataColumnWiseMonoid : tableViewDataRowWiseMonoid;

        const muiChartsData: MuiOverTimeChartData<T> = processData(
          r.data,
          query.xAxisId ?? 0,
          query.isBenchmark,
          query.segmentations,
          displayState,
          metricDetailsMap,
          formatLabel,
          getAliasForMetric,
          locale.selected,
          colors,
          t
        );
        // TODO: This works only for non segmented which isn't good
        const combinedTableViewData = tableViewDataMonoid.combine(
          (acc.tableViewData ?? tableViewDataMonoid.empty) as NonSegmentedTableOverTimeData<MetricIdType>,
          muiChartsData.tableViewData as NonSegmentedTableOverTimeData<MetricIdType>
        );
        const combinedXAxisData = [...(acc.xAxisConfig ?? []), ...(muiChartsData.xAxisConfig ?? [])];
        return {
          ...acc,
          series: [...(acc.series ?? []), ...muiChartsData.series],
          tableViewData: combinedTableViewData,
          // The last query's bar-label config wins.
          barLabelConfig: muiChartsData.barLabelConfig,
          metricSql: [...(acc.metricSql ?? []), ...muiChartsData.metricSql],
          // Should make sure this yAxis and xAxis stuff doesn't impact other charts on AlexDB
          yAxisConfig: removeDuplicates([...(acc.yAxisConfig ?? []), ...(muiChartsData.yAxisConfig ?? [])]),
          xAxisConfig: combinedXAxisData,
        };
      }, {} as MuiOverTimeChartData<T>);
    },
  });

  return {
    data,
    displayHandle,
    metrics: totalMetrics,
    segmentations: totalSegmentations,
    filters: totalFilters,
    // Assumes at least one query; all queries share the slider's selection when one is set.
    timeSelection: queryConfigs[0].timeSelection,
    error,
  };
};
