import {
  Box,
  Button,
  ButtonGroup,
  FormControlLabel,
  Typography,
} from '@mui/material'

import { AggregatedTableContext } from '../../../context/AggregatedTableContext'
import type { ButtonProps } from '@mui/material/Button'
import { ConfigContext } from '../../../context/ConfigContext'
import React from 'react'
import { SecondaryTableContext } from '../../../context/SecondaryTableContext'
import {
  ALL_METRICS_ORDERED,
  COUNT_LOGOS_VIEW_DISABLED_METRICS,
} from '../../BridgeTable/utils'
import { compact } from 'lodash'
import { Metric } from '../../../types'
import { SetFilter } from 'ag-grid-enterprise'

/**
 * Panel with a two-button group that switches the aggregation between
 * `sum` (labelled "$") and `count` (labelled "#").
 *
 * Picking a mode updates the aggregation type in `SecondaryTableContext`,
 * clears the chart data, calls `applyAggregation`, and rewrites the `metric`
 * set filter so that metrics listed in `COUNT_LOGOS_VIEW_DISABLED_METRICS`
 * are deselected while in `count` mode.
 *
 * Renders an empty `<span />` when the `toggle_count_dollar` feature is not
 * enabled in the config.
 *
 * @param applyAggregation - Called with the aggregation name to apply, both
 *   when the user picks a mode and when the context value differs from the
 *   aggregation function currently set on the `amount` column.
 */
export const AggregationSwitchPanel: React.FC<{
  applyAggregation: (agg: string) => void
}> = ({ applyAggregation }) => {
  const { config } = React.useContext(ConfigContext)

  const { gridRef } = React.useContext(AggregatedTableContext)
  const { aggregationType, setAggregationType, setChartData } =
    React.useContext(SecondaryTableContext)
  // Extra props spread onto whichever button matches the active aggregation.
  const selected: ButtonProps = {
    variant: 'contained',
  }

  const selectAgg = React.useCallback(
    (target: 'sum' | 'count') => {
      // Clicking the already-active button must not flip the aggregation.
      if (target === aggregationType) {
        return
      }

      setAggregationType(target)
      setChartData(null)
      applyAggregation(target)

      // The filter instance may be missing (e.g. grid API not available yet),
      // so keep null/undefined in the type instead of asserting it away.
      const metricFilter = gridRef?.current?.api?.getFilterInstance(
        'metric',
      ) as SetFilter<unknown> | null | undefined

      if (!metricFilter) {
        return
      }

      metricFilter.refreshFilterValues()

      const pickValues = () => {
        if (target !== 'count') {
          // Switching back to `sum` re-selects every metric.
          return ALL_METRICS_ORDERED
        }

        // In `count` mode start from the filter's current values (or every
        // metric when the filter reports none) and drop the disabled metrics
        // together with any falsy entries.
        const current = metricFilter.getValues()
        return compact(
          (current.length ? current : ALL_METRICS_ORDERED).filter(
            (i) => !COUNT_LOGOS_VIEW_DISABLED_METRICS.includes(i as Metric),
          ),
        )
      }

      metricFilter
        .setModel({
          filterType: 'set',
          values: pickValues(),
        })
        .then(() => {
          metricFilter.applyModel()
          gridRef?.current?.api?.onFilterChanged()
        })
    },
    [
      applyAggregation,
      setAggregationType,
      setChartData,
      aggregationType,
      gridRef,
    ],
  )

  // Automatically mark all metrics as selected by default
  React.useEffect(() => {
    const metricFilter = gridRef?.current?.api?.getFilterInstance(
      'metric',
    ) as SetFilter<unknown> | null | undefined

    if (!metricFilter) {
      return
    }

    const values = metricFilter.getValues()
    const metrics =
      aggregationType === 'count'
        ? ALL_METRICS_ORDERED.filter(
            (i) => !COUNT_LOGOS_VIEW_DISABLED_METRICS.includes(i),
          )
        : ALL_METRICS_ORDERED

    // Only seed the filter when it reports no values at all.
    if (!values?.length) {
      metricFilter.setModel({ filterType: 'set', values: metrics }).then(() => {
        metricFilter.applyModel()
        gridRef?.current?.api?.onFilterChanged()
      })
    }
  }, [gridRef, aggregationType])

  // sync aggregation types from URL and ag-grid
  React.useEffect(() => {
    const agAggregationType = gridRef?.current?.columnApi
      ?.getColumn('amount')
      ?.getAggFunc()
    if (aggregationType !== agAggregationType) {
      applyAggregation(aggregationType)
    }
  }, [aggregationType, applyAggregation, gridRef])

  if (!config?.features_dict.toggle_count_dollar?.enabled) {
    // Ag Grid doesn't support null as return for react component
    return <span />
  }

  return (
    <Box pb={1} pt={1} ml={3}>
      <FormControlLabel
        labelPlacement="start"
        control={
          <Box ml={1}>
            <ButtonGroup
              disableElevation
              variant="outlined"
              size="small"
              color="primary"
            >
              <Button
                {...(aggregationType === 'sum' ? selected : {})}
                onClick={() => selectAgg('sum')}
              >
                $
              </Button>
              <Button
                {...(aggregationType === 'count' ? selected : {})}
                onClick={() => selectAgg('count')}
              >
                #
              </Button>
            </ButtonGroup>
          </Box>
        }
        label={<Typography variant="caption">Aggregation:</Typography>}
      />
    </Box>
  )
}
