import { WriteDocument } from 'application/document'
import { NodeSizeStateMap } from '../node/map'
import {
  getGap,
  getLayoutDirection,
  getPadding,
} from 'application/layout/utils'
import { ReadOnlyNode } from 'application/node'
import { NodeSizeState } from '../node/node'
import { truncate } from 'application/math'
import {
  computeMinContent,
  divideChildren,
  filterInLayoutChildren,
  sumChildrenFlexBase,
} from '../utils'

/**
 * Resolves the main-axis size of the in-layout children of a flex container.
 *
 * For every line of children returned by `divideChildren`, the hypothetical
 * main sizes are compared with the space available in the container (inner
 * size minus gaps). Depending on the sign of the difference, children are
 * grown, shrunk, or simply assigned their hypothetical main size. Results
 * are written to the children's size state via `setW` / `setH`.
 */
export class FlexMainSize {
  private document: WriteDocument
  private sizeMap: NodeSizeStateMap

  constructor(document: WriteDocument, sizeMap: NodeSizeStateMap) {
    this.document = document
    this.sizeMap = sizeMap
  }

  /**
   * Computes and stores the main-axis size of every in-layout child of the
   * node with the given id. Does nothing when the node or its size state is
   * missing, or when the node has no in-layout children.
   *
   * @param id Id of the flex container node.
   */
  calculateMainSize = (id: string): void => {
    const node = this.document.getNode(id)
    if (!node) return

    const size = this.sizeMap.get(id)
    if (!size) return

    const children = filterInLayoutChildren(
      node.getChildren() || [],
      this.document
    )
    if (children.length === 0) return

    const direction = getLayoutDirection(node)
    const main = direction === 'column' ? 'h' : 'w'
    const innerSize = this.getContainerInnerSize(node, size)

    const dividedChildren = divideChildren(
      node,
      this.document,
      this.sizeMap,
      innerSize
    )

    // The gap depends only on the container, so read it once for all lines.
    const gap = getGap(node)

    for (const subset of dividedChildren) {
      // An empty line has no gaps; never produce a negative total gap.
      const totalGap = gap * Math.max(subset.length - 1, 0)
      const availableSpace = innerSize - totalGap
      const freeSpace = availableSpace - this.getHypoMainSize(subset)

      if (freeSpace > 0) {
        this.distributePositiveSpace(subset, availableSpace, main)
      } else if (freeSpace < 0) {
        this.distributeNegativeSpace(subset, availableSpace, main)
      } else {
        this.setFlexChildrenMainSize(subset, main)
      }
    }
  }

  /**
   * Grows the children of one line so that they fill `space`.
   *
   * Children whose flex base size exceeds their hypothetical main size, or
   * whose `flex.grow` is undefined or 0, are frozen at their hypothetical
   * main size. The remaining free space is split equally between the
   * unfrozen children (the `flex.grow` value is only checked for being
   * non-zero); a child that would exceed its max size is clamped and frozen,
   * and the leftover is redistributed on the next pass.
   *
   * Children that cannot be resolved (missing node, size state, hypothetical
   * or base size) are excluded from the distribution and left untouched.
   *
   * @param children Ids of the children on the line.
   * @param space Space available to the children (gaps already removed).
   * @param main Main axis of the container.
   */
  private distributePositiveSpace = (
    children: string[],
    space: number,
    main: 'w' | 'h'
  ): void => {
    const frozen = new Set<string>()
    const targets = new Map<string, number>()

    for (const childId of children) {
      const item = this.resolveFlexItem(childId)
      if (!item) {
        // Without sizes there is nothing to grow; keep it out of the loop
        // below so it can neither receive NaN nor stall the distribution.
        frozen.add(childId)
        continue
      }

      if (item.base > item.hypo) {
        frozen.add(childId)
        targets.set(childId, item.hypo)
        continue
      }

      const grow = item.node.getStyleAttribute('flex.grow')
      if (grow === undefined || grow === 0) {
        frozen.add(childId)
        targets.set(childId, item.hypo)
        continue
      }

      targets.set(childId, item.base)
    }

    let remainingChildren = children.filter((childId) => !frozen.has(childId))
    let remainingFreeSpace = space - this.sumTargets(targets)

    // Stop once less than one unit is left to hand out.
    while (remainingFreeSpace > 1 && remainingChildren.length > 0) {
      const growValue = remainingFreeSpace / remainingChildren.length

      for (const childId of remainingChildren) {
        const childSize = this.sizeMap.get(childId)
        const current = targets.get(childId)
        if (!childSize || current === undefined) {
          // Freeze so that every pass is guaranteed to make progress.
          frozen.add(childId)
          continue
        }

        const grown = current + growValue
        const max = main === 'w' ? childSize.getMaxW() : childSize.getMaxH()
        if (max !== undefined && grown > max) {
          targets.set(childId, max)
          frozen.add(childId)
        } else {
          targets.set(childId, grown)
        }
      }

      remainingFreeSpace = space - this.sumTargets(targets)
      remainingChildren = remainingChildren.filter(
        (childId) => !frozen.has(childId)
      )
    }

    this.applyTargets(children, targets, main)
  }

  /**
   * Shrinks the children of one line so that they fit into `space`.
   *
   * Children whose flex base size is below their hypothetical main size, or
   * whose `flex.shrink` is undefined or 0, are frozen at their hypothetical
   * main size. The excess is removed from the unfrozen children in
   * proportion to their shrink factor (see `getShrinkFactor`); a child that
   * would drop below its min-content size is clamped and frozen, and the
   * remaining excess is redistributed on the next pass.
   *
   * Children that cannot be resolved (missing node, size state, hypothetical
   * or base size) are excluded from the distribution and left untouched.
   *
   * @param children Ids of the children on the line.
   * @param space Space available to the children (gaps already removed).
   * @param main Main axis of the container.
   */
  private distributeNegativeSpace = (
    children: string[],
    space: number,
    main: 'w' | 'h'
  ): void => {
    const frozen = new Set<string>()
    const targets = new Map<string, number>()

    for (const childId of children) {
      const item = this.resolveFlexItem(childId)
      if (!item) {
        // Without sizes there is nothing to shrink; keep it out of the loop
        // below so it can never receive a NaN size.
        frozen.add(childId)
        continue
      }

      if (item.base < item.hypo) {
        frozen.add(childId)
        targets.set(childId, item.hypo)
        continue
      }

      const shrink = item.node.getStyleAttribute('flex.shrink')
      if (shrink === undefined || shrink === 0) {
        frozen.add(childId)
        targets.set(childId, item.hypo)
        continue
      }

      targets.set(childId, item.base)
    }

    let remainingChildren = children.filter((childId) => !frozen.has(childId))
    let remainingExcessSpace = this.sumTargets(targets) - space
    let totalShrinkFactor = this.sumShrinkFactors(remainingChildren, main)

    // Stop once less than one unit of overflow is left.
    while (
      remainingExcessSpace > 1 &&
      remainingChildren.length > 0 &&
      totalShrinkFactor > 0
    ) {
      for (const childId of remainingChildren) {
        const current = targets.get(childId)
        if (current === undefined) {
          frozen.add(childId)
          continue
        }

        const shrink = this.getShrinkFactor(childId, main)
        if (shrink === 0) continue

        const ratio = shrink / totalShrinkFactor
        const shrunk =
          current - Math.abs(truncate(remainingExcessSpace * ratio, 3))

        const minContent = computeMinContent(
          childId,
          this.sizeMap,
          this.document,
          main
        )
        if (shrunk < minContent) {
          targets.set(childId, minContent)
          frozen.add(childId)
        } else {
          targets.set(childId, shrunk)
        }
      }

      remainingExcessSpace = this.sumTargets(targets) - space
      remainingChildren = remainingChildren.filter(
        (childId) => !frozen.has(childId)
      )
      totalShrinkFactor = this.sumShrinkFactors(remainingChildren, main)
    }

    this.applyTargets(children, targets, main)
  }

  /**
   * Assigns every child its hypothetical main size. Children without a size
   * state or without a hypothetical main size are skipped.
   */
  private setFlexChildrenMainSize = (
    children: string[],
    main: 'w' | 'h'
  ): void => {
    for (const childId of children) {
      const childSize = this.sizeMap.get(childId)
      if (!childSize) continue

      const hypoMainSize = childSize.getHypoMainSize()
      if (hypoMainSize === undefined) continue

      this.setMainSize(childSize, hypoMainSize, main)
    }
  }

  /**
   * Sums the hypothetical main sizes of the given children; children without
   * a size state or hypothetical main size contribute 0.
   */
  private getHypoMainSize = (children: string[]): number => {
    return children.reduce(
      (acc, childId) =>
        acc + (this.sizeMap.get(childId)?.getHypoMainSize() || 0),
      0
    )
  }

  /**
   * Returns the container's inner size along its main axis. When the inner
   * size is not defined, falls back to the summed flex base sizes of the
   * children on that axis.
   */
  private getContainerInnerSize = (
    node: ReadOnlyNode,
    size: NodeSizeState
  ): number => {
    const direction = getLayoutDirection(node)
    switch (direction) {
      case 'row':
      case 'wrap': {
        const innerW = size.getInnerW()
        if (innerW !== undefined) return innerW
        return sumChildrenFlexBase(node, this.sizeMap, this.document, 'w')
      }
      case 'column': {
        const innerH = size.getInnerH()
        if (innerH !== undefined) return innerH
        return sumChildrenFlexBase(node, this.sizeMap, this.document, 'h')
      }
    }
  }

  /**
   * Returns the weight used to share overflow between shrinking children:
   * the child's flex base size minus its padding on the main axis. Returns 0
   * when the node, its size state or its flex base size is missing.
   */
  private getShrinkFactor = (id: string, main: 'w' | 'h'): number => {
    const node = this.document.getNode(id)
    if (!node) return 0

    const size = this.sizeMap.get(id)
    if (!size) return 0

    const base = size.getFlexBaseSize()
    if (base === undefined) return 0

    switch (main) {
      case 'w': {
        const left = getPadding(node, 'left')
        const right = getPadding(node, 'right')
        return base - left - right
      }
      case 'h': {
        const top = getPadding(node, 'top')
        const bottom = getPadding(node, 'bottom')
        return base - top - bottom
      }
    }
  }

  /**
   * Looks up everything needed to flex a child. Returns undefined when the
   * node, its size state, its hypothetical main size or its flex base size
   * is missing.
   */
  private resolveFlexItem = (childId: string) => {
    const node = this.document.getNode(childId)
    if (!node) return undefined

    const size = this.sizeMap.get(childId)
    if (!size) return undefined

    const hypo = size.getHypoMainSize()
    const base = size.getFlexBaseSize()
    if (hypo === undefined || base === undefined) return undefined

    return { node, size, hypo, base }
  }

  /** Sums all target sizes; an empty map sums to 0. */
  private sumTargets = (targets: Map<string, number>): number => {
    let total = 0
    for (const target of targets.values()) total += target
    return total
  }

  /** Sums the shrink factors of the given children, counting NaN as 0. */
  private sumShrinkFactors = (
    children: string[],
    main: 'w' | 'h'
  ): number => {
    return children.reduce(
      (acc, childId) => acc + (this.getShrinkFactor(childId, main) || 0),
      0
    )
  }

  /**
   * Writes the resolved target sizes to the children's size state. Children
   * without a size state or without a target are left untouched.
   */
  private applyTargets = (
    children: string[],
    targets: Map<string, number>,
    main: 'w' | 'h'
  ): void => {
    for (const childId of children) {
      const childSize = this.sizeMap.get(childId)
      if (!childSize) continue

      const target = targets.get(childId)
      if (target === undefined) continue

      this.setMainSize(childSize, target, main)
    }
  }

  /** Sets the size on the given main axis, passing `true` as second argument as the setters require. */
  private setMainSize = (
    size: NodeSizeState,
    value: number,
    main: 'w' | 'h'
  ): void => {
    if (main === 'w') {
      size.setW(value, true)
    } else {
      size.setH(value, true)
    }
  }
}
