package ujson
import scala.annotation.switch
import upickle.core.{ArrVisitor, ObjVisitor, CharBuilder, RenderUtils}

/**
  * A specialized JSON renderer that renders JSON text as Chars directly
  * to a [[java.io.Writer]] or [[java.io.OutputStream]].
  *
  * Note that we use an internal `CharBuilder` to buffer the output
  * before sending it to [[out]] in batches. This lets us benefit from the high
  * performance and minimal overhead of `CharBuilder` in the fast path of
  * pushing characters, and avoid the synchronization/polymorphism overhead of
  * [[out]] on the fast path. Most [[out]]s would also have performance
  * benefits from receiving data in batches, rather than char by char.
  */
class BaseCharRenderer[T <: upickle.core.CharOps.Output]
                      (out: T,
                       indent: Int = -1,
                       escapeUnicode: Boolean = false) extends JsVisitor[T, T]{
  // `indent` is the number of spaces per nesting level; -1 disables all
  // newlines/indentation (compact output). `escapeUnicode` is forwarded to
  // RenderUtils.escapeChar when rendering strings.

  // Main output buffer; its contents are handed to `out` by flushCharBuilder().
  private[this] val charBuilder = new upickle.core.CharBuilder
  // Scratch buffer passed to RenderUtils.escapeChar when escaping strings.
  private[this] val unicodeCharBuilder = new upickle.core.CharBuilder()

  /**
    * Hands buffered chars to `out` via `writeOutToIfLongerThan`, using a
    * threshold of 0 at the top level (`depth == 0`) and 1000 while nested,
    * so that `out` receives data in batches rather than char by char.
    */
  def flushCharBuilder() = {
    charBuilder.writeOutToIfLongerThan(out, if (depth == 0) 0 else 1000)
  }

  // Current array/object nesting level; drives indentation and the flush threshold.
  private[this] var depth: Int = 0

  // True while an object key is being rendered (set by visitKey, cleared by
  // visitKeyValue). Keys are quoted by the object visitor, not by the value renderers.
  private[this] var visitingKey = false

  // Separators are emitted lazily: a ',' (plus indentation) is owed before the
  // next array element / object entry, and an opening '"' is owed before the
  // next key. Both are written by flushBuffer() when the next value starts.
  private[this] var commaBuffered = false
  private[this] var quoteBuffered = false

  /** Writes any pending ',' (followed by indentation) and any pending opening key quote. */
  def flushBuffer() = {
    if (commaBuffered) {
      commaBuffered = false
      charBuilder.append(',')
      renderIndent()
    }
    if (quoteBuffered) {
      quoteBuffered = false
      charBuilder.append('"')
    }
  }

  /**
    * Starts a JSON array: writes '[' and returns a visitor that separates the
    * elements with ',' and closes the array with ']'. `length` is not used.
    */
  def visitArray(length: Int, index: Int) = new ArrVisitor[T, T] {
    flushBuffer()
    charBuilder.append('[')

    depth += 1
    renderIndent()
    def subVisitor = BaseCharRenderer.this
    def visitValue(v: T, index: Int): Unit = {
      flushBuffer()
      // The element has been rendered; owe a ',' before the next one.
      commaBuffered = true
    }
    def visitEnd(index: Int) = {
      // No trailing comma after the last element.
      commaBuffered = false
      depth -= 1
      renderIndent()
      charBuilder.append(']')
      flushCharBuilder()
      out
    }
  }

  /**
    * Starts a JSON object: writes '{' and returns a visitor for its entries.
    *
    * Keys are rendered by this renderer with `visitingKey` set: the opening
    * quote comes from `quoteBuffered` (written by flushBuffer) and the closing
    * quote from `visitKeyValue`, so non-string keys such as numbers are quoted
    * too, and strings rendered as keys must not add their own quotes.
    */
  def visitJsonableObject(length: Int, index: Int) = new ObjVisitor[T, T] {
    flushBuffer()
    charBuilder.append('{')
    depth += 1
    renderIndent()
    def subVisitor = BaseCharRenderer.this
    def visitKey(index: Int) = {
      quoteBuffered = true
      visitingKey = true
      BaseCharRenderer.this
    }
    def visitKeyValue(s: Any): Unit = {
      // Close the key's quotes, then emit the key/value separator.
      charBuilder.append('"')
      visitingKey = false
      charBuilder.append(':')
      if (indent != -1) charBuilder.append(' ')
    }
    def visitValue(v: T, index: Int): Unit = {
      commaBuffered = true
    }
    def visitEnd(index: Int) = {
      commaBuffered = false
      depth -= 1
      renderIndent()
      charBuilder.append('}')
      flushCharBuilder()
      out
    }
  }

  def visitNull(index: Int) = {
    flushBuffer()
    BaseCharRenderer.appendNull(charBuilder)
    flushCharBuilder()
    out
  }

  def visitFalse(index: Int) = {
    flushBuffer()
    BaseCharRenderer.appendFalse(charBuilder)
    flushCharBuilder()
    out
  }

  def visitTrue(index: Int) = {
    flushBuffer()
    BaseCharRenderer.appendTrue(charBuilder)
    flushCharBuilder()
    out
  }

  /** Writes the number text `s` verbatim; `decIndex`/`expIndex` are not used. */
  def visitFloat64StringParts(s: CharSequence, decIndex: Int, expIndex: Int, index: Int) = {
    flushBuffer()
    BaseCharRenderer.appendKnownAsciiString(charBuilder, s)
    flushCharBuilder()
    out
  }

  /**
    * Renders a Float. Infinities and NaN become the strings "Infinity",
    * "-Infinity" and "NaN"; whole numbers exactly representable as an Int or
    * Long are written in integer notation; everything else goes through
    * FloatToDecimalChar.
    */
  override def visitFloat32(d: Float, index: Int) = {
    d match{
      case Float.PositiveInfinity => visitNonNullString("Infinity", -1)
      case Float.NegativeInfinity => visitNonNullString("-Infinity", -1)
      case d if java.lang.Float.isNaN(d) => visitNonNullString("NaN", -1)
      case d =>
        // Ensure that whole numbers that can be exactly represented by an
        // int or long are written in int notation, without decimal points or
        // exponents.
        // `toInt`/`toLong` saturate at MaxValue, and MaxValue widened to Float
        // rounds up to 2^31 / 2^63, which would wrongly "match" d == 2^31 or
        // d == 2^63 and print a value off by one. Neither MaxValue is ever an
        // exact Float, so excluding them is always safe.
        val i = d.toInt
        if (i != Int.MaxValue && d == i) visitInt32(i, index)
        else {
          val l = d.toLong
          flushBuffer()
          if (l != Long.MaxValue && l == d) {
            BaseCharRenderer.appendKnownAsciiString(charBuilder, l.toString)
          } else {
            // presumably the maximum length FloatToDecimalChar can produce
            charBuilder.ensureLength(15)
            charBuilder.length += ujson.FloatToDecimalChar.toString(charBuilder.arr, charBuilder.length, d)
          }
          flushCharBuilder()
        }
    }
    out
  }

  /**
    * Renders a Double. Infinities and NaN become the strings "Infinity",
    * "-Infinity" and "NaN"; whole numbers exactly representable as an Int or
    * Long are written in integer notation; everything else goes through
    * DoubleToDecimalChar.
    */
  override def visitFloat64(d: Double, index: Int) = {
    d match{
      case Double.PositiveInfinity => visitNonNullString("Infinity", -1)
      case Double.NegativeInfinity => visitNonNullString("-Infinity", -1)
      case d if java.lang.Double.isNaN(d) => visitNonNullString("NaN", -1)
      case d =>
        // Ensure that whole numbers that can be exactly represented by an
        // int or long are written in int notation, without decimal points or
        // exponents.
        // Every Int is exact as a Double, so the Int check needs no guard. But
        // `toLong` saturates at Long.MaxValue, which widens back to 2^63 and
        // would wrongly match d == 2^63; Long.MaxValue is never an exact
        // Double, so exclude it.
        val i = d.toInt
        if (d == i) visitInt32(i, index)
        else {
          val l = d.toLong
          flushBuffer()
          if (l != Long.MaxValue && l == d) {
            BaseCharRenderer.appendKnownAsciiString(charBuilder, l.toString)
          } else {
            // presumably the maximum length DoubleToDecimalChar can produce
            charBuilder.ensureLength(24)
            charBuilder.length += ujson.DoubleToDecimalChar.toString(charBuilder.arr, charBuilder.length, d)
          }
          flushCharBuilder()
        }
    }
    out
  }

  override def visitInt32(i: Int, index: Int) = {
    flushBuffer()
    BaseCharRenderer.appendIntString(charBuilder, i)
    flushCharBuilder()
    out
  }

  /**
    * Renders a Long. Values outside +/-2^53 (the range in which every integer
    * is exactly representable as an IEEE-754 double) are written as JSON
    * strings, presumably so readers that parse numbers as doubles don't lose
    * precision.
    */
  override def visitInt64(i: Long, index: Int) = {
    flushBuffer()
    // Long.MinValue needs its own check: math.abs(Long.MinValue) overflows
    // back to Long.MinValue. Object keys are already quoted (quoteBuffered +
    // visitKeyValue), so quoting them again would emit `""123""`.
    val quoted = !visitingKey &&
      (math.abs(i) > 9007199254740992L /*math.pow(2, 53)*/ ||
        i == -9223372036854775808L /*Long.MinValue*/ )
    if (quoted) charBuilder.append('"')
    BaseCharRenderer.appendLongString(charBuilder, i)
    if (quoted) charBuilder.append('"')
    flushCharBuilder()
    out
  }

  /**
    * Renders an unsigned 64-bit value stored in a Long. A negative Long
    * encodes a value >= 2^63, so only non-negative values that fit in an Int
    * take the Int fast path; everything else is delegated to
    * JsVisitor.visitUInt64.
    */
  override def visitUInt64(i: Long, index: Int) = {
    val int = i.toInt
    if (i >= 0 && int == i) visitInt32(int, index)
    else super.visitUInt64(i, index)
    out
  }

  /** Renders `s` as a JSON string, or `null` when `s` is null. */
  def visitString(s: CharSequence, index: Int) = {

    if (s eq null) visitNull(index)
    else visitNonNullString(s, index)
  }

  /**
    * Escapes and writes `s` via RenderUtils.escapeChar. The last argument is
    * `!visitingKey`: keys get their quotes from the object visitor instead.
    */
  def visitNonNullString(s: CharSequence, index: Int) = {
    flushBuffer()

    upickle.core.RenderUtils.escapeChar(
      unicodeCharBuilder, charBuilder, s, escapeUnicode, !visitingKey
    )

    flushCharBuilder()
    out
  }

  /**
    * When indentation is enabled (`indent != -1`), writes a newline followed
    * by `indent * depth` spaces; otherwise does nothing.
    */
  final def renderIndent() = {
    if (indent == -1) ()
    else {
      val spaces = indent * depth
      charBuilder.ensureLength(spaces + 1)
      BaseCharRenderer.renderIdent(charBuilder.arr, charBuilder.length, spaces)
      charBuilder.length += spaces + 1
    }
  }
}

object BaseCharRenderer{
  /**
    * Writes a newline at `arr(length)` followed by `i0` spaces at
    * `arr(length + 1)` .. `arr(length + i0)`. The caller must already have
    * ensured room for `i0 + 1` chars and is responsible for advancing the
    * builder's length. (Despite the spelling, this renders an indent.)
    */
  private def renderIdent(arr: Array[Char], length: Int, i0: Int) = {
    var i = i0
    arr(length) = '\n'
    while (i > 0) {
      arr(length + i) = ' '
      i -= 1
    }
  }

  /**
    * Appends the decimal representation of `i0` to `eb`.
    * `RenderUtils.intStringSize` is assumed to return the number of chars
    * needed, sign included.
    */
  private def appendIntString(eb: CharBuilder, i0: Int) = {
    val size = RenderUtils.intStringSize(i0)
    // Digits are written right to left ending at `newLength`, so compute the end first.
    val newLength = eb.length + size
    eb.ensureLength(size)
    appendIntString0(i0, newLength, eb.arr)
    eb.length = newLength
  }

  /**
    * Writes the decimal digits of `i0` (with a leading '-' if negative) into
    * `arr`, right to left, so that the last char lands at `index - 1`.
    * Returns the index of the first char written.
    *
    * Like the JDK original, it works on the negated (non-positive) value:
    * every positive Int can be negated, but Int.MinValue cannot, so this
    * avoids overflow. Each loop iteration emits two digits via the
    * DigitOnes/DigitTens lookup tables.
    */
  private def appendIntString0(i0: Int, index: Int, arr: Array[Char]) = {
    // Copied from java.lang.Integer.getChars
    var i = i0
    var q = 0
    var r = 0
    var charPos = index
    val negative = i < 0
    if (!negative) i = -i
    // Generate two digits per iteration
    while (i <= -100) {
      q = i / 100
      // `/` truncates toward zero and i < 0, so r is the lowest two digits, in 0..99
      r = (q * 100) - i
      i = q
      charPos -= 1
      arr(charPos) = DigitOnes(r)
      charPos -= 1
      arr(charPos) = DigitTens(r)
    }
    // We know there are at most two digits left at this point.
    q = i / 10
    r = (q * 10) - i
    charPos -= 1
    arr(charPos) = ('0' + r).toChar
    // Whatever left is the remaining digit.
    if (q < 0) {
      charPos -= 1
      arr(charPos) = ('0' - q).toChar
    }
    if (negative) {
      charPos -= 1;
      arr(charPos) = '-'.toChar
    }
    charPos
  }

  /**
    * Appends the decimal representation of `i0` to `eb`.
    * `RenderUtils.longStringSize` is assumed to return the number of chars
    * needed, sign included.
    */
  private def appendLongString(eb: CharBuilder, i0: Long) = {
    val size = RenderUtils.longStringSize(i0)
    // Digits are written right to left ending at `newLength`, so compute the end first.
    val newLength = eb.length + size
    eb.ensureLength(size)
    appendLongString0(i0, newLength, eb.arr)
    eb.length = newLength
  }

  /**
    * Long counterpart of appendIntString0: writes the digits of `i0` right to
    * left so the last char lands at `buf(index - 1)`, and returns the index of
    * the first char written. It uses Long arithmetic only while the negated
    * value is <= Integer.MIN_VALUE, then switches to cheaper Int arithmetic
    * for the remaining digits.
    */
  private def appendLongString0(i0: Long, index: Int, buf: Array[Char]) = {
    // Copied from java.lang.Long.getChars
    var i = i0
    var q = 0L
    var r = 0
    var charPos = index
    val negative = i < 0
    if (!negative) i = -i
    // Get 2 digits/iteration using longs until quotient fits into an int
    while (i <= Integer.MIN_VALUE) {
      q = i / 100
      r = ((q * 100) - i).toInt
      i = q
      charPos -= 1
      buf(charPos) = DigitOnes(r)
      charPos -= 1
      buf(charPos) = DigitTens(r)
    }
    // Get 2 digits/iteration using ints
    var q2 = 0
    var i2 = i.toInt
    while (i2 <= -100) {
      q2 = i2 / 100
      r = (q2 * 100) - i2
      i2 = q2
      charPos -= 1;
      buf(charPos) = DigitOnes(r)
      charPos -= 1;
      buf(charPos) = DigitTens(r)
    }
    // We know there are at most two digits left at this point.
    q2 = i2 / 10
    r = (q2 * 10) - i2
    charPos -= 1
    buf(charPos) = ('0' + r).toChar
    // Whatever left is the remaining digit.
    if (q2 < 0) {
      charPos -= 1
      buf(charPos) = ('0' - q2).toChar
    }
    if (negative) {
      charPos -= 1
      buf(charPos) = '-'.toChar
    }
    charPos
  }

  // Lookup table: DigitTens(n) is the tens digit of n, for 0 <= n < 100.
  private val DigitTens = Array[Char](
    '0', '0', '0', '0', '0', '0', '0', '0', '0', '0',
    '1', '1', '1', '1', '1', '1', '1', '1', '1', '1',
    '2', '2', '2', '2', '2', '2', '2', '2', '2', '2',
    '3', '3', '3', '3', '3', '3', '3', '3', '3', '3',
    '4', '4', '4', '4', '4', '4', '4', '4', '4', '4',
    '5', '5', '5', '5', '5', '5', '5', '5', '5', '5',
    '6', '6', '6', '6', '6', '6', '6', '6', '6', '6',
    '7', '7', '7', '7', '7', '7', '7', '7', '7', '7',
    '8', '8', '8', '8', '8', '8', '8', '8', '8', '8',
    '9', '9', '9', '9', '9', '9', '9', '9', '9', '9',
  )

  // Lookup table: DigitOnes(n) is the ones digit of n, for 0 <= n < 100.
  private val DigitOnes = Array[Char](
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
  )

  /** Appends the literal `null` to `eb`. */
  private def appendNull(eb: CharBuilder) = {
    eb.ensureLength(4)
    appendNull0(eb.arr, eb.length)
    eb.length += 4
  }

  // Writes "null" at arr(arrOffset) .. arr(arrOffset + 3); capacity must already be ensured.
  private def appendNull0(arr: Array[Char], arrOffset: Int) = {
    arr(arrOffset) = 'n'.toChar
    arr(arrOffset + 1) = 'u'.toChar
    arr(arrOffset + 2) = 'l'.toChar
    arr(arrOffset + 3) = 'l'.toChar
  }

  /** Appends the literal `true` to `eb`. */
  private def appendTrue(eb: CharBuilder) = {
    eb.ensureLength(4)
    appendTrue0(eb.arr, eb.length)
    eb.length += 4
  }

  // Writes "true" at arr(arrOffset) .. arr(arrOffset + 3); capacity must already be ensured.
  private def appendTrue0(arr: Array[Char], arrOffset: Int) = {
    arr(arrOffset) = 't'.toChar
    arr(arrOffset + 1) = 'r'.toChar
    arr(arrOffset + 2) = 'u'.toChar
    arr(arrOffset + 3) = 'e'.toChar
  }

  /** Appends the literal `false` to `eb`. */
  private def appendFalse(eb: CharBuilder) = {
    eb.ensureLength(5)
    appendFalse0(eb.arr, eb.length)
    eb.length += 5
  }

  // Writes "false" at arr(arrOffset) .. arr(arrOffset + 4); capacity must already be ensured.
  private def appendFalse0(arr: Array[Char], arrOffset: Int) = {
    arr(arrOffset) = 'f'.toChar
    arr(arrOffset + 1) = 'a'.toChar
    arr(arrOffset + 2) = 'l'.toChar
    arr(arrOffset + 3) = 's'.toChar
    arr(arrOffset + 4) = 'e'.toChar
  }

  /**
    * Copies `s` into `eb` verbatim, with no escaping or validation. Callers
    * must only pass text that is already valid JSON as-is (in this file it
    * presumably only ever receives number text).
    */
  private def appendKnownAsciiString(eb: CharBuilder, s: CharSequence) = {
    val sLength = s.length
    eb.ensureLength(sLength)
    appendKnownAsciiString0(eb.arr, eb.length, s, sLength)

    eb.length += sLength
  }

  // Copies the first `sLength` chars of `s` to arr(arrOffset) onwards; capacity must already be ensured.
  private def appendKnownAsciiString0(arr: Array[Char], arrOffset: Int, s: CharSequence, sLength: Int) = {
    var i = 0
    while (i < sLength) {
      arr(arrOffset + i) = s.charAt(i).toChar
      i += 1
    }
  }
}