// included from expression.cpp!

#include "differentiate.h"
#include "expression.h"

// Fallback for node types that supply no derivative rule of their own.
// The variable name is ignored. A NULL result means "can not be
// differentiated".
KNode*
KNode::differentiate( QString )
{
  KNode* const notDifferentiable = NULL;
  return notDifferentiable;
} // KNode::differentiate


// Sum rule: ( left + right )' = left' + right'
//
// var: name of the variable to differentiate with respect to.
// Returns a newly allocated KNodePlus holding both derivatives, or NULL as
// soon as one operand reports that it can not be differentiated.
KNode*
KNodePlus::differentiate( QString var )
{
  KNodePlus* sum = new KNodePlus;
  CHECK_PTR( sum );

  // the right operand is only differentiated once the left one succeeded
  sum->left = left->differentiate( var );
  if ( sum->left ) {
    sum->right = right->differentiate( var );
    if ( sum->right )
      return sum;
  } // if

  // one operand failed: discard the partially built result
  delete sum;
  return NULL;
} // KNodePlus::differentiate


// Difference rule: ( left - right )' = left' - right'
//
// var: name of the variable to differentiate with respect to.
// Returns a newly allocated KNodeMinus holding both derivatives, or NULL
// as soon as one operand reports that it can not be differentiated.
KNode*
KNodeMinus::differentiate( QString var )
{
  KNodeMinus* difference = new KNodeMinus;
  CHECK_PTR( difference );

  // the right operand is only differentiated once the left one succeeded
  difference->left = left->differentiate( var );
  if ( difference->left ) {
    difference->right = right->differentiate( var );
    if ( difference->right )
      return difference;
  } // if

  // one operand failed: discard the partially built result
  delete difference;
  return NULL;
} // KNodeMinus::differentiate


// Product rule: ( left * right )' = left' * right + left * right'
//
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodePlus) of a newly built tree, or NULL if either
// operand can not be differentiated. On failure the partially built tree
// is released through its root.
// NOTE(review): this relies on the node destructors freeing their
// children - confirm in expression.h.
KNode*
KNodeMult::differentiate( QString var )
{
  KNodePlus* plus = new KNodePlus;
  CHECK_PTR( plus );

  // left' * right
  KNodeMult* multLeft = new KNodeMult;
  CHECK_PTR( multLeft );
  plus->left = multLeft;
  multLeft->left = left->differentiate( var );
  if( !multLeft->left ) {
    delete plus;
    return static_cast<KNode*>( NULL );
  } // if
  multLeft->right = right->copy();

  // left * right'
  KNodeMult* multRight = new KNodeMult;
  CHECK_PTR( multRight );
  plus->right = multRight;
  multRight->left = left->copy();
  multRight->right = right->differentiate( var );
  // Test the derivative itself, not the freshly allocated multRight node:
  // otherwise a non-differentiable right operand would go unnoticed and a
  // tree containing a NULL child would be handed back to the caller.
  if( !multRight->right ) {
    delete plus;
    return static_cast<KNode*>( NULL );
  } // if

  return plus;
} // KNodeMult::differentiate


// Quotient rule:
//   ( left / right )' = ( left' * right - left * right' ) / right^2
//
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeDiv) of a newly built tree, or NULL if either
// operand can not be differentiated; in that case the partially built
// tree is released through its root.
KNode*
KNodeDiv::differentiate( QString var )
{
  KNodeDiv* quotient = new KNodeDiv;
  CHECK_PTR( quotient );

  // numerator: left' * right - left * right'
  KNodeMinus* numerator = new KNodeMinus;
  CHECK_PTR( numerator );
  quotient->left = numerator;

  // minuend: left' * right
  KNodeMult* minuend = new KNodeMult;
  CHECK_PTR( minuend );
  numerator->left = minuend;
  minuend->left = left->differentiate( var );
  if ( minuend->left == NULL ) {
    delete quotient;
    return NULL;
  } // if
  minuend->right = right->copy();

  // subtrahend: left * right'
  KNodeMult* subtrahend = new KNodeMult;
  CHECK_PTR( subtrahend );
  numerator->right = subtrahend;
  subtrahend->left = left->copy();
  subtrahend->right = right->differentiate( var );
  if ( subtrahend->right == NULL ) {
    delete quotient;
    return NULL;
  } // if

  // denominator: right^2
  KNodePow* square = new KNodePow;
  CHECK_PTR( square );
  quotient->right = square;
  square->left = right->copy();
  square->right = new KNodeValue( 2.0 );
  CHECK_PTR( square->right );

  return quotient;
} // KNodeDiv::differentiate


// General power rule:
//   ( left^right )' = left^right * ( right' * ln( left ) + right * left' / left )
//
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeMult) of a newly built tree. NULL is returned
// if either operand can not be differentiated or if no function named
// "ln" is found in dictFunctions; the partially built tree is then
// released through its root.
KNode*
KNodePow::differentiate( QString var )
{
  // outer product; its first factor is a copy of this node ( left^right )
  KNodeMult* product = new KNodeMult;
  CHECK_PTR( product );
  product->left = copy();

  // second factor: the sum inside the parentheses
  KNodePlus* sum = new KNodePlus;
  CHECK_PTR( sum );
  product->right = sum;

  // first summand: right' * ln( left )
  KNodeMult* logTerm = new KNodeMult;
  CHECK_PTR( logTerm );
  sum->left = logTerm;
  logTerm->left = right->differentiate( var );
  if ( logTerm->left == NULL ) {
    delete product;
    return NULL;
  } // if
  funcCall* lnFunc = (*dictFunctions)[ "ln" ];
  if ( lnFunc == NULL ) {
    warning( "ln unknown!" );
    delete product;
    return NULL;
  } // if
  KNodeFuncCall* lnCall = new KNodeFuncCall;
  CHECK_PTR( lnCall );
  logTerm->right = lnCall;
  lnCall->function = lnFunc;
  lnCall->param[ 0 ] = left->copy();

  // second summand: right * ( left' / left )
  KNodeMult* ratioTerm = new KNodeMult;
  CHECK_PTR( ratioTerm );
  sum->right = ratioTerm;
  ratioTerm->left = right->copy();
  KNodeDiv* ratio = new KNodeDiv;
  CHECK_PTR( ratio );
  ratioTerm->right = ratio;
  ratio->left = left->differentiate( var );
  if ( ratio->left == NULL ) {
    delete product;
    return NULL;
  } // if
  ratio->right = left->copy();

  return product;
} // KNodePow::differentiate


// The derivative of a constant is 0, whatever the variable is.
// Returns a newly allocated KNodeValue holding 0.0.
KNode*
KNodeValue::differentiate( QString )
{
  KNodeValue* zero = new KNodeValue( 0.0 );
  CHECK_PTR( zero );

  return zero;
} // KNodeValue::differentiate


// A variable differentiates to 1 with respect to itself and to 0 with
// respect to any other variable.
//
// var: name of the variable to differentiate with respect to.
// Returns a newly allocated KNodeValue holding 1.0 or 0.0.
KNode*
KNodeVar::differentiate( QString var )
{
  double slope = 0.0;
  if ( sVarName == var )
    slope = 1.0;

  KNodeValue* result = new KNodeValue( slope );
  CHECK_PTR( result );

  return result;
} // KNodeVar::differentiate


// Hands the work over to the called function's own differentiation
// routine, passing this call node along so the routine can reach its
// parameters.
//
// var: name of the variable to differentiate with respect to.
// Returns whatever function->differentiate() returns.
KNode*
KNodeFuncCall::differentiate( QString var )
{
  KNode* derivative = function->differentiate( var, this );
  return derivative;
} // KNodeFuncCall::differentiate


// sin( n )' = n' * cos( n )
//
// nfc: the sin call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeMult) of a newly built tree, or NULL if n can
// not be differentiated or no function named "cos" is found in
// dictFunctions.
KNode*
differentiate_sin( const KNodeFuncCall* nfc, QString var )
{
  KNodeMult* product = new KNodeMult;
  CHECK_PTR( product );

  // first factor: n'
  product->left = nfc->param[ 0 ]->differentiate( var );
  if ( product->left == NULL ) {
    delete product;
    return NULL;
  } // if

  // second factor: cos( n )
  funcCall* cosFunc = (*dictFunctions)[ "cos" ];
  if ( cosFunc == NULL ) {
    warning( "cos unknown!" );
    delete product;
    return NULL;
  } // if
  KNodeFuncCall* cosCall = new KNodeFuncCall;
  CHECK_PTR( cosCall );
  product->right = cosCall;
  cosCall->function = cosFunc;
  cosCall->param[ 0 ] = nfc->param[ 0 ]->copy();

  return product;
} // differentiate_sin

// cos( n )' = ( -1 * n' ) * sin( n )
//
// nfc: the cos call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeMult) of a newly built tree, or NULL if n can
// not be differentiated or no function named "sin" is found in
// dictFunctions.
KNode*
differentiate_cos( const KNodeFuncCall* nfc, QString var )
{
  KNodeMult* product = new KNodeMult;
  CHECK_PTR( product );

  // first factor: -1 * n'
  KNodeMult* negated = new KNodeMult;
  CHECK_PTR( negated );
  product->left = negated;
  negated->left = new KNodeValue( -1.0 );
  CHECK_PTR( negated->left );
  negated->right = nfc->param[ 0 ]->differentiate( var );
  if ( negated->right == NULL ) {
    delete product;
    return NULL;
  } // if

  // second factor: sin( n )
  funcCall* sinFunc = (*dictFunctions)[ "sin" ];
  if ( sinFunc == NULL ) {
    warning( "sin unknown!" );
    delete product;
    return NULL;
  } // if
  KNodeFuncCall* sinCall = new KNodeFuncCall;
  CHECK_PTR( sinCall );
  product->right = sinCall;
  sinCall->function = sinFunc;
  sinCall->param[ 0 ] = nfc->param[ 0 ]->copy();

  return product;
} // differentiate_cos

// tan( n )' = n' / cos( n )^2
//
// nfc: the tan call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeDiv) of a newly built tree, or NULL if n can
// not be differentiated or no function named "cos" is found in
// dictFunctions.
KNode*
differentiate_tan( const KNodeFuncCall* nfc, QString var )
{
  KNodeDiv* quotient = new KNodeDiv;
  CHECK_PTR( quotient );

  // numerator: n'
  quotient->left = nfc->param[ 0 ]->differentiate( var );
  if ( quotient->left == NULL ) {
    delete quotient;
    return NULL;
  } // if

  // denominator: ( ... )^2, the base is filled in below
  KNodePow* square = new KNodePow;
  CHECK_PTR( square );
  quotient->right = square;
  square->right = new KNodeValue( 2.0 );
  CHECK_PTR( square->right );

  // base of the power: cos( n )
  funcCall* cosFunc = (*dictFunctions)[ "cos" ];
  if ( cosFunc == NULL ) {
    warning( "cos unknown!" );
    delete quotient;
    return NULL;
  } // if
  KNodeFuncCall* cosCall = new KNodeFuncCall;
  CHECK_PTR( cosCall );
  square->left = cosCall;
  cosCall->function = cosFunc;
  cosCall->param[ 0 ] = nfc->param[ 0 ]->copy();

  return quotient;
} // differentiate_tan



// asin( n )' = n' / sqrt( 1 - n^2 )
//
// nfc: the asin call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeDiv) of a newly built tree, or NULL if n can
// not be differentiated or no function named "sqrt" is found in
// dictFunctions.
KNode*
differentiate_asin( const KNodeFuncCall* nfc, QString var )
{
  KNodeDiv* div = new KNodeDiv;
  CHECK_PTR( div );

  // numerator: n'
  div->left = nfc->param[ 0 ]->differentiate( var );
  if( !div->left ) {
    delete div;
    return static_cast<KNode*>( NULL );
  } // if

  // denominator: sqrt( ... )
  funcCall* fc_sqrt = (*dictFunctions)[ "sqrt" ];
  if ( !fc_sqrt ) {
    warning( "sqrt unknown!" );
    delete div;
    return static_cast<KNode*>( NULL );
  } // if
  KNodeFuncCall* fc = new KNodeFuncCall;
  CHECK_PTR( fc );
  div->right = fc;
  fc->function = fc_sqrt;

  // argument of sqrt: 1 - n^2
  KNodeMinus* minus = new KNodeMinus;
  CHECK_PTR( minus );
  fc->param[ 0 ] = minus;
  minus->left = new KNodeValue( 1.0 );
  CHECK_PTR( minus->left );

  // n^2
  KNodePow* pow = new KNodePow;
  CHECK_PTR( pow );
  minus->right = pow;
  pow->left = nfc->param[ 0 ]->copy();
  pow->right = new KNodeValue( 2.0 );
  // check the exponent node that was just allocated; pow itself has
  // already been checked above
  CHECK_PTR( pow->right );

  return div;
} // differentiate_asin

// acos( n )' = -1 * ( n' / sqrt( 1 - n^2 ) )
//
// nfc: the acos call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeMult) of a newly built tree, or NULL if n can
// not be differentiated or no function named "sqrt" is found in
// dictFunctions.
KNode*
differentiate_acos( const KNodeFuncCall* nfc, QString var )
{
  // outer product: -1 * ( ... )
  KNodeMult* mult = new KNodeMult;
  CHECK_PTR( mult );
  mult->left = new KNodeValue( -1.0 );
  CHECK_PTR( mult->left );

  // n' / sqrt( ... )
  KNodeDiv* div = new KNodeDiv;
  CHECK_PTR( div );
  mult->right = div;
  div->left = nfc->param[ 0 ]->differentiate( var );
  if( !div->left ) {
    delete mult;
    return static_cast<KNode*>( NULL );
  } // if

  // denominator: sqrt( ... )
  funcCall* fc_sqrt = (*dictFunctions)[ "sqrt" ];
  if ( !fc_sqrt ) {
    warning( "sqrt unknown!" );
    delete mult;
    return static_cast<KNode*>( NULL );
  } // if
  KNodeFuncCall* fc = new KNodeFuncCall;
  CHECK_PTR( fc );
  div->right = fc;
  fc->function = fc_sqrt;

  // argument of sqrt: 1 - n^2
  KNodeMinus* minus = new KNodeMinus;
  CHECK_PTR( minus );
  fc->param[ 0 ] = minus;
  minus->left = new KNodeValue( 1.0 );
  CHECK_PTR( minus->left );

  // n^2
  KNodePow* pow = new KNodePow;
  CHECK_PTR( pow );
  minus->right = pow;
  pow->left = nfc->param[ 0 ]->copy();
  pow->right = new KNodeValue( 2.0 );
  // check the exponent node that was just allocated; pow itself has
  // already been checked above
  CHECK_PTR( pow->right );

  return mult;
} // differentiate_acos

// atan( n )' = n' / ( 1 + n^2 )
//
// nfc: the atan call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeDiv) of a newly built tree, or NULL if n can
// not be differentiated.
KNode*
differentiate_atan( const KNodeFuncCall* nfc, QString var )
{
  KNodeDiv* div = new KNodeDiv;
  CHECK_PTR( div );

  // numerator: n'
  div->left = nfc->param[ 0 ]->differentiate( var );
  if( !div->left ) {
    delete div;
    return static_cast<KNode*>( NULL );
  } // if

  // denominator: 1 + n^2
  KNodePlus* plus = new KNodePlus;
  CHECK_PTR( plus );
  div->right = plus;
  plus->left = new KNodeValue( 1.0 );
  CHECK_PTR( plus->left );

  // n^2
  KNodePow* pow = new KNodePow;
  CHECK_PTR( pow );
  plus->right = pow;
  pow->left = nfc->param[ 0 ]->copy();
  pow->right = new KNodeValue( 2.0 );
  // check the exponent node that was just allocated; pow itself has
  // already been checked above
  CHECK_PTR( pow->right );

  return div;
} // differentiate_atan


// ln( n )' = n' / n
//
// nfc: the ln call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeDiv) of a newly built tree, or NULL if n can
// not be differentiated.
KNode*
differentiate_ln( const KNodeFuncCall* nfc, QString var )
{
  KNodeDiv* quotient = new KNodeDiv;
  CHECK_PTR( quotient );

  // numerator: n'; the denominator is only copied once n' is available
  quotient->left = nfc->param[ 0 ]->differentiate( var );
  if ( quotient->left ) {
    quotient->right = nfc->param[ 0 ]->copy();
    return quotient;
  } // if

  // n could not be differentiated
  delete quotient;
  return NULL;
} // differentiate_ln


// exp( n )' = n' * e^n
//
// nfc: the exp call node; n is its param[ 0 ].
// var: name of the variable to differentiate with respect to.
// Returns the root (a KNodeMult) of a newly built tree, or NULL if n can
// not be differentiated. e^n is built as a KNodePow whose base is the
// constant M_E.
KNode*
differentiate_exp( const KNodeFuncCall* nfc, QString var )
{
  KNodeMult* product = new KNodeMult;
  CHECK_PTR( product );

  // first factor: n'
  product->left = nfc->param[ 0 ]->differentiate( var );
  if ( product->left == NULL ) {
    delete product;
    return NULL;
  } // if

  // second factor: e^n
  KNodePow* power = new KNodePow;
  product->right = power;
  CHECK_PTR( product->right );
  power->left = new KNodeValue( M_E );
  CHECK_PTR( power->left );
  power->right = nfc->param[ 0 ]->copy();

  return product;
} // differentiate_exp


// Runs the differentiation routine registered for this function, if any.
//
// var: name of the variable to differentiate with respect to.
// t:   the call node being differentiated; passed on to the routine.
// Returns the routine's result, or NULL when differentiateFunc is not set.
KNode*
funcCall::differentiate( QString var, const KNodeFuncCall* t ) const
{
  if ( !differentiateFunc )
    return NULL;

  return differentiateFunc( t, var );
} // funcCall::differentiate
