use super::util;
use std::cmp::{max, min, Ord};
use std::collections::HashSet;
use std::error::Error;
use std::ops::Add;
use std::str::FromStr;

struct Segment<Item> {
    x0: Item,
    y0: Item,
    x1: Item,
    y1: Item,
}

impl<Item> FromStr for Segment<Item>
where
    Item: FromStr + Ord,
    <Item as FromStr>::Err: Error + Send + Sync + 'static,
{
    type Err = Box<dyn Error + Send + Sync>;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut iter = s
            .split(" -> ")
            .map(|part| -> Result<(Item, Item), Self::Err> {
                let mut iter = part.split(',');
                let x = iter.next().ok_or(util::Error)?.parse()?;
                let y = iter.next().ok_or(util::Error)?.parse()?;
                if iter.next().is_none() {
                    Ok((x, y))
                } else {
                    Err(util::Error.into())
                }
            });
        let (x0, y0) = iter.next().ok_or(util::Error)??;
        let (x1, y1) = iter.next().ok_or(util::Error)??;
        if iter.next().is_none() {
            Ok(if x0 < x1 || x0 == x1 && y0 < y1 {
                Segment { x0, y0, x1, y1 }
            } else {
                Segment {
                    x0: x1,
                    y0: y1,
                    x1: x0,
                    y1: y0,
                }
            })
        } else {
            Err(util::Error.into())
        }
    }
}

struct SegmentIterator<Item> {
    x: Item,
    y: Item,
    dx: Item,
    dy: Item,
    len: usize,
}

impl<Item> Iterator for SegmentIterator<Item>
where
    Item: Copy + Add<Output = Item>,
{
    type Item = (Item, Item);
    fn next(&mut self) -> Option<Self::Item> {
        self.len = self.len.checked_sub(1)?;
        let item = (self.x, self.y);
        self.x = self.x + self.dx;
        self.y = self.y + self.dy;
        Some(item)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.len, Some(self.len))
    }
}

impl Segment<i32> {
    fn intersection(&self, other: &Segment<i32>) -> Option<SegmentIterator<i32>> {
        let ret = if self.x0 == self.x1 && other.x0 == other.x1 {
            if self.x0 != other.x0 {
                return None;
            }
            let y0 = max(self.y0, other.y0);
            let y1 = min(self.y1, other.y1);
            SegmentIterator {
                x: self.x0,
                y: y0,
                dx: 0,
                dy: 1,
                len: usize::try_from(y1 - y0).ok()? + 1,
            }
        } else if self.x0 == self.x1 {
            if !(other.x0..=other.x1).contains(&self.x0) {
                return None;
            }
            let y = (other.y1 - other.y0) / (other.x1 - other.x0) * (self.x0 - other.x0) + other.y0;
            if !(self.y0..=self.y1).contains(&y) {
                return None;
            }
            SegmentIterator {
                x: self.x0,
                y,
                dx: 0,
                dy: 0,
                len: 1,
            }
        } else if other.x0 == other.x1 {
            if !(self.x0..=self.x1).contains(&other.x0) {
                return None;
            }
            let y = (self.y1 - self.y0) / (self.x1 - self.x0) * (other.x0 - self.x0) + self.y0;
            if !(other.y0..=other.y1).contains(&y) {
                return None;
            }
            SegmentIterator {
                x: other.x0,
                y,
                dx: 0,
                dy: 0,
                len: 1,
            }
        } else {
            let m0 = (self.y1 - self.y0) / (self.x1 - self.x0);
            let m1 = (other.y1 - other.y0) / (other.x1 - other.x0);
            let a0 = self.y0 - m0 * self.x0;
            let a1 = other.y0 - m1 * other.x0;
            if m0 == m1 {
                if a0 != a1 {
                    return None;
                }
                let x0 = max(self.x0, other.x0);
                let x1 = min(self.x1, other.x1);
                SegmentIterator {
                    x: x0,
                    y: m0 * x0 + a0,
                    dx: 1,
                    dy: m0,
                    len: usize::try_from(x1 - x0).ok()? + 1,
                }
            } else if (a1 - a0) % (m1 - m0) == 0 {
                let x = -(a1 - a0) / (m1 - m0);
                if !(max(self.x0, other.x0)..=min(self.x1, other.x1)).contains(&x) {
                    return None;
                }
                SegmentIterator {
                    x,
                    y: m0 * x + a0,
                    dx: 0,
                    dy: 0,
                    len: 1,
                }
            } else {
                return None;
            }
        };
        Some(ret)
    }
}

pub fn part1<'a, I, S>(lines: I) -> Result<usize, Box<dyn Error + Send + Sync>>
where
    I: IntoIterator<Item = &'a S>,
    S: AsRef<str> + 'a,
{
    let segments = lines
        .into_iter()
        .map(|line| line.as_ref().parse())
        .filter_map(|res| {
            res.map(|segment: Segment<i32>| {
                if segment.x0 == segment.x1 || segment.y0 == segment.y1 {
                    Some(segment)
                } else {
                    None
                }
            })
            .transpose()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let points = segments
        .iter()
        .enumerate()
        .flat_map(|(i, segment0)| {
            segments[i + 1..]
                .iter()
                .filter_map(|segment1| segment0.intersection(segment1))
        })
        .flatten()
        .collect::<HashSet<_>>();
    Ok(points.len())
}

pub fn part2<'a, I, S>(lines: I) -> Result<usize, Box<dyn Error + Send + Sync>>
where
    I: IntoIterator<Item = &'a S>,
    S: AsRef<str> + 'a,
{
    let segments: Vec<Segment<i32>> = util::parse_many(lines)?;
    let points = segments
        .iter()
        .enumerate()
        .flat_map(|(i, segment0)| {
            segments[i + 1..]
                .iter()
                .filter_map(|segment1| segment0.intersection(segment1))
        })
        .flatten()
        .collect::<HashSet<_>>();
    Ok(points.len())
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    static EXAMPLE: &[&str] = &[
        "0,9 -> 5,9",
        "8,0 -> 0,8",
        "9,4 -> 3,4",
        "2,2 -> 2,1",
        "7,0 -> 7,4",
        "6,4 -> 2,0",
        "0,9 -> 2,9",
        "3,4 -> 1,4",
        "0,0 -> 8,8",
        "5,5 -> 8,2",
    ];

    #[test]
    fn part1_examples() -> Result<(), Box<dyn Error + Send + Sync>> {
        assert_eq!(5, part1(EXAMPLE)?);
        Ok(())
    }

    #[test]
    fn part2_examples() -> Result<(), Box<dyn Error + Send + Sync>> {
        assert_eq!(12, part2(EXAMPLE)?);
        Ok(())
    }
}