#include "solver/puzzle.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>

namespace {

// Cells are addressed by a single index in 0..80, laid out row by row on a
// 9x9 grid (index = 9 * row + column). The helpers below map a cell index to
// the first index of the row, column, or 3x3 box that contains it.

// Index of the first cell in the row containing `id`.
uint8_t RowStart(uint8_t id) { return id - id % 9; }

// Index of the first cell in the column containing `id`; this is also the
// zero-based column number.
uint8_t ColStart(uint8_t id) { return id - RowStart(id); }

// Index of the top-left cell of the 3x3 box containing `id`.
uint8_t BoxStart(uint8_t id) {
  // Round the row start down to a multiple of 27 (three full rows) and the
  // column down to a multiple of 3.
  const uint8_t row_start = RowStart(id);
  const uint8_t column = ColStart(id);
  const uint8_t band_start = row_start - row_start % 27;
  const uint8_t stack_start = column - column % 3;
  return band_start + stack_start;
}

// Top-left cell index of each of the nine boxes.
constexpr std::array<uint8_t, 9> kBoxStarts{
    0,  3,  6,   //
    27, 30, 33,  //
    54, 57, 60,
};

// Offsets from a box's top-left cell to each of the nine cells in that box.
constexpr std::array<uint8_t, 9> kBoxOffsets{
    0,  1,  2,   //
    9,  10, 11,  //
    18, 19, 20,
};

}  // namespace

// Builds a Puzzle from a textual description in which character i describes
// cell i: '1'..'9' assigns that digit to the cell, while '.' and '0' leave the
// cell unassigned.
//
// The input is expected to be exactly 81 characters long and to contain only
// the characters above; both conditions are asserted. When assertions are
// compiled out (NDEBUG), malformed input is handled defensively instead of
// invoking undefined behaviour: only the characters actually present (at most
// 81) are read, and any unrecognised character is treated as an unassigned
// cell.
Puzzle Puzzle::FromString(std::string puzzle) {
  assert(puzzle.length() == 81);
  Puzzle p;

  // Never index past the end of the input, even if the length assert above is
  // disabled.
  const std::string::size_type count =
      std::min<std::string::size_type>(puzzle.length(), 81);
  for (std::string::size_type i = 0; i < count; i++) {
    const char ch = puzzle[i];
    if (ch == '.' || ch == '0') {
      continue;
    }
    if (ch < '1' || ch > '9') {
      assert(false && "Invalid input character");
      // Without assertions, skip the character rather than assigning an
      // out-of-range value to the cell.
      continue;
    }
    p.AssignSquare(static_cast<uint8_t>(i), static_cast<uint8_t>(ch - '0'));
  }
  return p;
}

// Renders the grid as one field per cell, in cell order: the cell's value
// printed as a decimal number when the cell is solved, '.' otherwise.
std::string Puzzle::CurrentState() {
  std::ostringstream out;
  for (const Cell& cell : cells_) {
    if (!cell.IsSolved()) {
      out << '.';
      continue;
    }
    // Convert to int so the value is printed as a number, not as a character.
    out << static_cast<int>(cell.value());
  }
  return out.str();
}

std::string Puzzle::PencilMarkState() {
  std::ostringstream str;
  for (const Cell& c : cells_) {
    for (uint8_t i = 1; i <= 9; i++) {
      if (c.IsPossible(i)) {
        str << (int)i;
      }
    }
    str << ",";
  }
  // Erase the trailing ",".
  std::string temp = str.str();
  temp.erase(temp.end() - 1);
  return temp;
}

bool Puzzle::IsSolved() {
  return std::all_of(cells_.begin(), cells_.end(),
                     [](const Cell& c) { return c.IsSolved(); });
}

bool Puzzle::ApplyNextStep() {
  // Search for a naked single.
  for (int i = 0; i < 81; i++) {
    if (cells_[i].IsSolved()) {
      continue;
    }

    if (cells_[i].NumPossibilities() == 1) {
      for (uint8_t v = 1; v <= 9; v++) {
        if (cells_[i].IsPossible(v)) {
          AssignSquare(i, v);
          return true;
        }
      }
    }
  }
  // Search for a hidden single in a box.
  for (int box = 0; box < 9; box++) {
    uint8_t boxroot = kBoxStarts[box];
    // FIXME: We should be able to check all of the numbers at once.
    for (int n = 1; n <= 9; n++) {
      int8_t found_loc = -1;
      bool exit = false;
      for (int cellind = 0; cellind < 9 && !exit; cellind++) {
        Cell cell = cells_[boxroot + kBoxOffsets[cellind]];
        if (cell.IsSolved()) {
          if (cell.value() == n) {
            // This number is solved in this box, we can exit.
            exit = true;
          }
          continue;
        }
        if (cell.IsPossible(n)) {
          if (found_loc != -1) {
            exit = true;
          }
          found_loc = boxroot + kBoxOffsets[cellind];
        }
      }
      if (found_loc != -1 && !exit) {
        AssignSquare(found_loc, n);
        return true;
      }
    }
  }
  return false;
}

// Sets cell `id` to `value` and then calls Restrict(value) on every cell in
// the same row, the same column, and the same 3x3 box, in that order.
//
// `id` must be in 0..80 and `value` at most 9 (both asserted).
// NOTE(review): `value >= 0` is always true for an unsigned type, so 0 is
// accepted here even though the callers in this file only pass 1..9 - confirm
// whether 0 should be rejected.
// NOTE(review): the freshly assigned cell is itself part of its row, column
// and box, so it also receives Restrict(value) - presumably harmless for a
// solved cell; verify against Cell::Restrict.
void Puzzle::AssignSquare(uint8_t id, uint8_t value) {
  assert(id < 81);
  assert(value >= 0 && value <= 9);

  cells_[id] = Cell(value);

  // Row: nine consecutive cells starting at the row's first index.
  const uint8_t row_begin = RowStart(id);
  for (uint8_t step = 0; step < 9; ++step) {
    cells_[row_begin + step].Restrict(value);
  }

  // Column: every ninth cell starting from the column's top cell.
  for (uint8_t pos = ColStart(id); pos < 81; pos += 9) {
    cells_[pos].Restrict(value);
  }

  // Box: the nine cells at fixed offsets from the box's top-left cell.
  const uint8_t box_begin = BoxStart(id);
  for (const uint8_t delta : kBoxOffsets) {
    cells_[box_begin + delta].Restrict(value);
  }
}

// Default constructor: performs no work beyond default-initializing the
// members.
Puzzle::Puzzle() = default;