package ujson
import scala.annotation.switch
import upickle.core.{ArrVisitor, ObjVisitor, ByteBuilder, RenderUtils}

/**
  * A specialized JSON renderer that renders JSON as Bytes directly to an
  * output sink such as a [[java.io.Writer]] or [[java.io.OutputStream]]
  *
  * Note that we use an internal `ByteBuilder` to buffer the output internally
  * before sending it to [[out]] in batches. This lets us benefit from the high
  * performance and minimal overhead of `ByteBuilder` in the fast path of
  * pushing characters, and avoid the synchronization/polymorphism overhead of
  * [[out]] on the fast path. Most [[out]]s would also have performance
  * benefits from receiving data in batches, rather than byte by byte.
  */
class BaseByteRenderer[T <: upickle.core.ByteOps.Output]
                      (out: T,
                       indent: Int = -1,
                       escapeUnicode: Boolean = false) extends JsVisitor[T, T]{
  // Staging buffer: every visit method appends here first; the bytes reach
  // `out` only through `flushByteBuilder`.
  private[this] val byteBuilder = new upickle.core.ByteBuilder
  // Scratch builder handed to `RenderUtils.escapeByte` when rendering strings.
  private[this] val unicodeCharBuilder = new upickle.core.CharBuilder()

  /**
    * Hands the staged bytes over to `out` via `writeOutToIfLongerThan`.
    *
    * The threshold passed is 0 when `depth == 0` (i.e. no array/object is
    * currently open) and 1000 while nested, so nested values are batched.
    */
  def flushByteBuilder() = {
    byteBuilder.writeOutToIfLongerThan(out, if (depth == 0) 0 else 1000)
  }

  // Number of arrays/objects currently open; also scales the indentation.
  private[this] var depth: Int = 0

  // True between `visitKey` and `visitKeyValue` of an object visitor.
  private[this] var visitingKey = false

  // Separators are not written eagerly: they are recorded in these flags and
  // only emitted by `flushBuffer()` right before the next value is rendered.
  // This lets `visitEnd` drop a pending comma instead of emitting a trailing one.
  private[this] var commaBuffered = false
  private[this] var indentBuffered = false
  private[this] var quoteBuffered = false

  /**
    * Emits whatever separators are pending, in the order comma, indentation,
    * opening quote, and clears the corresponding flags.
    */
  def flushBuffer() = {
    if (commaBuffered) {
      commaBuffered = false
      byteBuilder.append(',')
    }
    if (indentBuffered){
      indentBuffered = false
      renderIndent()
    }
    if (quoteBuffered) {
      quoteBuffered = false
      byteBuilder.append('"')
    }
  }

  /**
    * Starts rendering a JSON array: writes `[`, increases `depth` and buffers
    * an indent for the first element. Elements are rendered by this renderer
    * itself (see `subVisitor`).
    */
  def visitArray(length: Int, index: Int) = new ArrVisitor[T, T] {
    flushBuffer()
    byteBuilder.append('[')

    depth += 1
    indentBuffered = true

    def subVisitor = BaseByteRenderer.this

    // Called after an element has been rendered: buffer the `,` and indent
    // that must precede the next element, if there is one.
    def visitValue(v: T, index: Int): Unit = {
      flushBuffer()
      commaBuffered = true
      indentBuffered = true
    }

    // Closes the array. The indent before `]` is only rendered when both flags
    // are still pending; the pending comma itself is always discarded.
    def visitEnd(index: Int) = {
      depth -= 1
      if (indentBuffered && commaBuffered) renderIndent()
      commaBuffered = false
      indentBuffered = false
      byteBuilder.append(']')
      flushByteBuilder()
      out
    }
  }

  /**
    * Starts rendering a JSON object: writes `{`, increases `depth` and buffers
    * an indent for the first key. Keys and values are both rendered by this
    * renderer itself.
    */
  def visitJsonableObject(length: Int, index: Int) = new ObjVisitor[T, T] {
    flushBuffer()
    byteBuilder.append('{')
    depth += 1
    indentBuffered = true

    def subVisitor = BaseByteRenderer.this

    // Buffers the opening quote of the key and marks that a key is being
    // rendered; the key itself is visited through the enclosing renderer.
    def visitKey(index: Int) = {
      quoteBuffered = true
      visitingKey = true
      BaseByteRenderer.this
    }

    // Called once the key has been rendered: writes the closing quote, the
    // `:` separator and, when indentation is enabled, a single space.
    def visitKeyValue(s: Any): Unit = {
      byteBuilder.append('"')
      visitingKey = false
      byteBuilder.append(':')
      if (indent != -1) byteBuilder.append(' ')
    }

    // Called after a value has been rendered: buffer the `,` and indent that
    // must precede the next key, if there is one.
    def visitValue(v: T, index: Int): Unit = {
      commaBuffered = true
      indentBuffered = true
    }

    // Closes the object; same pending-separator handling as the array visitor.
    def visitEnd(index: Int) = {
      depth -= 1
      if (indentBuffered && commaBuffered) renderIndent()
      commaBuffered = false
      indentBuffered = false
      byteBuilder.append('}')
      flushByteBuilder()
      out
    }
  }

  /** Renders the literal `null`. */
  def visitNull(index: Int) = {
    flushBuffer()
    BaseByteRenderer.appendNull(byteBuilder)
    flushByteBuilder()
    out
  }

  /** Renders the literal `false`. */
  def visitFalse(index: Int) = {
    flushBuffer()
    BaseByteRenderer.appendFalse(byteBuilder)
    flushByteBuilder()
    out
  }

  /** Renders the literal `true`. */
  def visitTrue(index: Int) = {
    flushBuffer()
    BaseByteRenderer.appendTrue(byteBuilder)
    flushByteBuilder()
    out
  }

  /**
    * Renders an already-formatted number verbatim. `s` is copied char by char,
    * each char narrowed to a byte, so it is expected to be ASCII; `decIndex`
    * and `expIndex` are not used here.
    */
  def visitFloat64StringParts(s: CharSequence, decIndex: Int, expIndex: Int, index: Int) = {
    flushBuffer()
    BaseByteRenderer.appendKnownAsciiString(byteBuilder, s)
    flushByteBuilder()
    out
  }

  /**
    * Renders a `Float`. Infinities and NaN are rendered through
    * `visitNonNullString` as "Infinity", "-Infinity" and "NaN"; whole numbers
    * that fit an `Int` or `Long` are rendered in integer notation; everything
    * else goes through `ujson.FloatToDecimalByte`.
    */
  override def visitFloat32(d: Float, index: Int) = {
    d match{
      case Float.PositiveInfinity => visitNonNullString("Infinity", -1)
      case Float.NegativeInfinity => visitNonNullString("-Infinity", -1)
      case d if java.lang.Float.isNaN(d) => visitNonNullString("NaN", -1)
      case d =>
        // Ensure that for whole numbers that can be exactly represented by an
        // int or long, write them in int notation without decimal points or exponents
        val i = d.toInt
        if (d == i) visitInt32(i, index)
        else {
          val i = d.toLong
          flushBuffer()
          // Render the Long, not the Float: `d.toString` would produce
          // something like "1.0E10" instead of "10000000000". This also keeps
          // the branch consistent with `visitFloat64`.
          if (i == d) BaseByteRenderer.appendKnownAsciiString(byteBuilder, i.toString)
          else {
            byteBuilder.ensureLength(15)
            byteBuilder.length += ujson.FloatToDecimalByte.toString(byteBuilder.arr, byteBuilder.length, d)
          }
          flushByteBuilder()
        }
    }
    out
  }

  /**
    * Renders a `Double`. Infinities and NaN are rendered through
    * `visitNonNullString` as "Infinity", "-Infinity" and "NaN"; whole numbers
    * that fit an `Int` or `Long` are rendered in integer notation; everything
    * else goes through `ujson.DoubleToDecimalByte`.
    */
  override def visitFloat64(d: Double, index: Int) = {
    d match{
      case Double.PositiveInfinity => visitNonNullString("Infinity", -1)
      case Double.NegativeInfinity => visitNonNullString("-Infinity", -1)
      case d if java.lang.Double.isNaN(d) => visitNonNullString("NaN", -1)
      case d =>
        // Ensure that for whole numbers that can be exactly represented by an
        // int or long, write them in int notation without decimal points or exponents
        val i = d.toInt
        if (d == i) visitInt32(i, index)
        else {
          // NOTE(review): `toLong` saturates for |d| >= 2^63 and the `i == d`
          // comparison is then done in floating point -- confirm the output
          // for values at that boundary is acceptable.
          val i = d.toLong
          flushBuffer()
          if (i == d) BaseByteRenderer.appendKnownAsciiString(byteBuilder, i.toString)
          else {
            byteBuilder.ensureLength(24)
            byteBuilder.length += ujson.DoubleToDecimalByte.toString(byteBuilder.arr, byteBuilder.length, d)
          }
          flushByteBuilder()
        }
    }
    out
  }

  /** Renders an `Int` in plain decimal notation. */
  override def visitInt32(i: Int, index: Int) = {
    flushBuffer()
    BaseByteRenderer.appendIntString(byteBuilder, i)
    flushByteBuilder()
    out
  }

  /**
    * Renders a `Long` in decimal notation. Values whose magnitude exceeds
    * 2^53, as well as `Long.MinValue` (for which `math.abs` overflows), are
    * wrapped in double quotes, i.e. emitted as JSON strings -- presumably so
    * that consumers which read numbers as doubles do not lose precision.
    */
  override def visitInt64(i: Long, index: Int) = {
    flushBuffer()
    if (math.abs(i) > 9007199254740992L /*math.pow(2, 53)*/ ||
        i == -9223372036854775808L /*Long.MinValue*/ ) {
      byteBuilder.append('"')
      BaseByteRenderer.appendLongString(byteBuilder, i)
      byteBuilder.append('"')
    } else BaseByteRenderer.appendLongString(byteBuilder, i)
    flushByteBuilder()
    out
  }

  /**
    * Renders an unsigned 64-bit value: when it round-trips through `Int` it is
    * rendered via `visitInt32`, otherwise it is delegated to the parent
    * implementation.
    */
  override def visitUInt64(i: Long, index: Int) = {
    val int = i.toInt
    if (int == i) visitInt32(int, index)
    else super.visitUInt64(i, index)
    out
  }

  /** Renders a string, or `null` when `s` is a null reference. */
  def visitString(s: CharSequence, index: Int) = {

    if (s eq null) visitNull(index)
    else visitNonNullString(s, index)
  }

  /**
    * Renders a non-null string through `RenderUtils.escapeByte`, forwarding
    * `escapeUnicode`. The last argument is `!visitingKey`: while a key is
    * being rendered its quotes are written by `flushBuffer`/`visitKeyValue`
    * instead (presumably this flag tells `escapeByte` whether to add the
    * surrounding quotes itself -- confirm against `RenderUtils`).
    */
  def visitNonNullString(s: CharSequence, index: Int) = {
    flushBuffer()

    upickle.core.RenderUtils.escapeByte(
      unicodeCharBuilder, byteBuilder, s, escapeUnicode, !visitingKey
    )

    flushByteBuilder()
    out
  }

  /**
    * Writes a newline followed by `indent * depth` spaces, or nothing at all
    * when indentation is disabled (`indent == -1`).
    */
  final def renderIndent() = {
    if (indent == -1) ()
    else {
      val i = indent * depth
      // One extra byte for the leading newline.
      byteBuilder.ensureLength(i + 1)
      BaseByteRenderer.renderIdent(byteBuilder.arr, byteBuilder.length, i)
      byteBuilder.length += i + 1
    }
  }
}

object BaseByteRenderer{
  /**
    * Writes a newline at `arr(length)` followed by `i0` spaces in the slots
    * right after it. The caller is responsible for having reserved `i0 + 1`
    * bytes and for advancing its own length afterwards.
    */
  private def renderIdent(arr: Array[Byte], length: Int, i0: Int) = {
    arr(length) = '\n'
    var offset = 1
    while (offset <= i0) {
      arr(length + offset) = ' '
      offset += 1
    }
  }

  /** Appends the decimal representation of `i0` to `eb`. */
  private def appendIntString(eb: ByteBuilder, i0: Int) = {
    val width = RenderUtils.intStringSize(i0)
    val end = eb.length + width
    eb.ensureLength(width)
    // Digits are produced right-to-left, ending just before `end`.
    appendIntString0(i0, end, eb.arr)
    eb.length = end
  }

  /**
    * Writes the decimal digits of `i0` into `arr` backwards, so that the last
    * digit lands at `index - 1`. Returns the position of the first byte
    * written. Follows the algorithm of `java.lang.Integer.getChars`.
    */
  private def appendIntString0(i0: Int, index: Int, arr: Array[Byte]) = {
    val isNegative = i0 < 0
    // Work on the negated magnitude so that Int.MinValue needs no special case.
    var remaining = if (isNegative) i0 else -i0
    var pos = index
    // Peel off two digits per iteration.
    while (remaining <= -100) {
      val quotient = remaining / 100
      val pair = (quotient * 100) - remaining
      remaining = quotient
      pos -= 1
      arr(pos) = DigitOnes(pair)
      pos -= 1
      arr(pos) = DigitTens(pair)
    }
    // At most two digits are left now.
    val tens = remaining / 10
    val ones = (tens * 10) - remaining
    pos -= 1
    arr(pos) = ('0' + ones).toByte
    // A non-zero `tens` is the leading digit.
    if (tens < 0) {
      pos -= 1
      arr(pos) = ('0' - tens).toByte
    }
    if (isNegative) {
      pos -= 1
      arr(pos) = '-'.toByte
    }
    pos
  }

  /** Appends the decimal representation of `i0` to `eb`. */
  private def appendLongString(eb: ByteBuilder, i0: Long) = {
    val width = RenderUtils.longStringSize(i0)
    val end = eb.length + width
    eb.ensureLength(width)
    // Digits are produced right-to-left, ending just before `end`.
    appendLongString0(i0, end, eb.arr)
    eb.length = end
  }

  /**
    * Writes the decimal digits of `i0` into `buf` backwards, so that the last
    * digit lands at `index - 1`. Returns the position of the first byte
    * written. Follows the algorithm of `java.lang.Long.getChars`.
    */
  private def appendLongString0(i0: Long, index: Int, buf: Array[Byte]) = {
    val isNegative = i0 < 0
    // Work on the negated magnitude so that Long.MinValue needs no special case.
    var wide = if (isNegative) i0 else -i0
    var pos = index
    // Two digits per iteration with Long arithmetic until the rest fits an Int.
    while (wide <= Integer.MIN_VALUE) {
      val quotient = wide / 100
      val pair = ((quotient * 100) - wide).toInt
      wide = quotient
      pos -= 1
      buf(pos) = DigitOnes(pair)
      pos -= 1
      buf(pos) = DigitTens(pair)
    }
    // Two digits per iteration with Int arithmetic.
    var narrow = wide.toInt
    while (narrow <= -100) {
      val quotient = narrow / 100
      val pair = (quotient * 100) - narrow
      narrow = quotient
      pos -= 1
      buf(pos) = DigitOnes(pair)
      pos -= 1
      buf(pos) = DigitTens(pair)
    }
    // At most two digits are left now.
    val tens = narrow / 10
    val ones = (tens * 10) - narrow
    pos -= 1
    buf(pos) = ('0' + ones).toByte
    // A non-zero `tens` is the leading digit.
    if (tens < 0) {
      pos -= 1
      buf(pos) = ('0' - tens).toByte
    }
    if (isNegative) {
      pos -= 1
      buf(pos) = '-'.toByte
    }
    pos
  }

  // Lookup table: ASCII tens digit of every value in 0..99.
  private val DigitTens = Array.tabulate[Byte](100)(n => ('0' + n / 10).toByte)

  // Lookup table: ASCII ones digit of every value in 0..99.
  private val DigitOnes = Array.tabulate[Byte](100)(n => ('0' + n % 10).toByte)

  /** Appends the four bytes of `null` to `eb`. */
  private def appendNull(eb: ByteBuilder) = {
    eb.ensureLength(4)
    appendNull0(eb.arr, eb.length)
    eb.length += 4
  }

  /** Writes `null` into `arr` starting at `arrOffset`. */
  private def appendNull0(arr: Array[Byte], arrOffset: Int) = {
    arr(arrOffset + 0) = 'n'.toByte
    arr(arrOffset + 1) = 'u'.toByte
    arr(arrOffset + 2) = 'l'.toByte
    arr(arrOffset + 3) = 'l'.toByte
  }

  /** Appends the four bytes of `true` to `eb`. */
  private def appendTrue(eb: ByteBuilder) = {
    eb.ensureLength(4)
    appendTrue0(eb.arr, eb.length)
    eb.length += 4
  }

  /** Writes `true` into `arr` starting at `arrOffset`. */
  private def appendTrue0(arr: Array[Byte], arrOffset: Int) = {
    arr(arrOffset + 0) = 't'.toByte
    arr(arrOffset + 1) = 'r'.toByte
    arr(arrOffset + 2) = 'u'.toByte
    arr(arrOffset + 3) = 'e'.toByte
  }

  /** Appends the five bytes of `false` to `eb`. */
  private def appendFalse(eb: ByteBuilder) = {
    eb.ensureLength(5)
    appendFalse0(eb.arr, eb.length)
    eb.length += 5
  }

  /** Writes `false` into `arr` starting at `arrOffset`. */
  private def appendFalse0(arr: Array[Byte], arrOffset: Int) = {
    arr(arrOffset + 0) = 'f'.toByte
    arr(arrOffset + 1) = 'a'.toByte
    arr(arrOffset + 2) = 'l'.toByte
    arr(arrOffset + 3) = 's'.toByte
    arr(arrOffset + 4) = 'e'.toByte
  }

  /**
    * Appends `s` to `eb`, one byte per char. Each char is narrowed with
    * `toByte`, so `s` must only contain ASCII characters.
    */
  private def appendKnownAsciiString(eb: ByteBuilder, s: CharSequence) = {
    val count = s.length
    eb.ensureLength(count)
    appendKnownAsciiString0(eb.arr, eb.length, s, count)
    eb.length += count
  }

  /** Copies the first `sLength` chars of `s` into `arr` from `arrOffset` on. */
  private def appendKnownAsciiString0(arr: Array[Byte], arrOffset: Int, s: CharSequence, sLength: Int) = {
    var k = 0
    while (k < sLength) {
      arr(arrOffset + k) = s.charAt(k).toByte
      k += 1
    }
  }
}