/* Copyright (C) 2003 Reliable Software Group 
 *                    - University of California, Santa Barbara
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 */

/* $Id: markov.cpp,v 1.23 2003/05/08 23:59:30 chris Exp $ */

#include <fstream>

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include "markov.h" 

#undef not_defined__

#define CHECK_CONSISTENCY
#undef SHOW_PROBABILITY
#undef DUMP_FILES


#define SOFT_MAX 20
#define ABSOLUTE_MAX 40


/******************************************************************************/
/*                        Model Transition Methods                            */
/******************************************************************************/

bool ModelTrans::has_target(ModelState *target)
{
    return (this->target == target); 
}


/******************************************************************************/
/*                        Model State Methods                                 */
/******************************************************************************/

/* Create an empty state (no emissions, no edges) belonging to model 'm'.
 * The state takes the next id handed out by the model. */
ModelState::ModelState(HiddenMarkovModel *m)
{
  model = m;
  id = m->get_next_state_id();

  trans_count = emission_count = 0;

  /* Give the cached probability factors a defined, neutral value instead of
     leaving them indeterminate: get_structural_probability() and
     get_likelihood_probability() copy p_* into last_* when asked to
     overwrite, and print_state() prints all four values. */
  p_structure = last_structure = 1.0;
  p_likelihood = last_likelihood = 1.0;
}


/* Create a state belonging to model 'm' that initially emits 'token' once.
 * The state takes the next id handed out by the model and calls
 * token->lock() (presumably taking a reference that is later dropped with
 * release()). */
ModelState::ModelState(HiddenMarkovModel *m, Item *token) : model(m)
{
  id = m->get_next_state_id();

  trans_count = 0;

  /* Give the cached probability factors a defined, neutral value instead of
     leaving them indeterminate: get_structural_probability() and
     get_likelihood_probability() copy p_* into last_* when asked to
     overwrite, and print_state() prints all four values. */
  p_structure = last_structure = 1.0;
  p_likelihood = last_likelihood = 1.0;

  /* lock the item and add it to the emissions */
  token->lock();

  emissions[token] = 1;
  emission_count = 1;
}

/* A state owns every ModelTrans object in both of its edge tables (each
 * add_*_edge call allocates its own), so free them all here. */
ModelState::~ModelState()
{
  TransIter edge = trans_forward.begin();
  while (edge != trans_forward.end()) {
    delete edge->second;
    ++edge;
  }

  edge = trans_back.begin();
  while (edge != trans_back.end()) {
    delete edge->second;
    ++edge;
  }
}

/* Restore the structural factor saved by the last overwriting
 * get_structural_probability() call. Returns last/current, i.e. the factor
 * by which a product containing the current value must be rescaled. */
long double ModelState::reset_structural_probability()
{
#ifdef SHOW_PROBABILITY
  cout << "called reset_structural_probability() on " << id << "\n";
#endif

  const long double rescale = last_structure / p_structure;

  p_structure = last_structure;

  return rescale;
}

/* Compute this state's structural prior
 *
 *     S ^ (-#forward edges)  *  (A + 1) ^ (-#distinct emitted tokens)
 *
 * where A is the model's alphabet size and S is the model size when
 * 'reduced' is set, or model size + 1 otherwise. The result is stored in
 * p_structure; when 'overwrite' is set the previous p_structure is first
 * saved in last_structure. */
long double ModelState::get_structural_probability(bool reduced, bool overwrite)
{
  const long double state_base = reduced
    ? (long double) (model->get_model_size())
    : (long double) (model->get_model_size() + 1);
  const long double alphabet_base = (long double) (model->get_alphabet_size() + 1);

  long double prior = 1.0;
  prior *= pow(state_base, -(long double) (trans_forward.size()));
  prior *= pow(alphabet_base, -(long double) (emissions.size()));

#ifdef SHOW_PROBABILITY
  cout << "called get_structural_probability() on " << id << " = "  << prior  << " and set = " << overwrite << "\n";
#endif

  if (overwrite)
    last_structure = p_structure;
  p_structure = prior;

  return prior;
}

/* the structural factor saved by the last overwriting computation */
long double ModelState::get_last_structural_probability()
{
  const long double saved = last_structure;
#ifdef SHOW_PROBABILITY
  cout << "called get_last_structural_probability() on " << id << " = "  << saved << "\n";
#endif
  return saved;
}



/* Restore the likelihood factor saved by the last overwriting
 * get_likelihood_probability() call. Returns last/current, i.e. the factor
 * by which a product containing the current value must be rescaled. */
long double ModelState::reset_likelihood_probability()
{
#ifdef SHOW_PROBABILITY
  cout << "called reset_likelihood_probability() on " << id << "\n";
#endif

  const long double rescale = last_likelihood / p_likelihood;

  p_likelihood = last_likelihood;

  return rescale;
}

long double ModelState::get_likelihood_probability(bool overwrite)
{
  long double tmp, p = 1.0;
  EmissionIter emit_iter;
  TransIter trans_iter;

  if ((trans_forward.size() != 0 && trans_count == 0) || (emissions.size() != 0 && emission_count == 0))
    return 0.0;

  for (trans_iter = trans_forward.begin(); trans_iter != trans_forward.end(); ++trans_iter) {
    tmp = ((long double) trans_iter->second->count / (double) trans_count);
    p *= pow(tmp, (long double) trans_iter->second->count);
  }

  for (emit_iter = emissions.begin(); emit_iter != emissions.end(); ++emit_iter) {
    tmp = ((long double) emit_iter->second / (long double) emission_count);
    p *= pow(tmp, (long double) emit_iter->second);
  }

  if (overwrite)
    last_likelihood = p_likelihood;
  p_likelihood = p;

#ifdef SHOW_PROBABILITY
  cout << "called get_likelihood_probability() on " << id << " = "  << p  << " and set = " << overwrite << "\n";
#endif

  return p;
}

/* the likelihood factor saved by the last overwriting computation */
long double ModelState::get_last_likelihood_probability()
{
  const long double saved = last_likelihood;
#ifdef SHOW_PROBABILITY
  cout << "called get_last_likelihood_probability() on " << id << " = "  << saved << "\n";
#endif

  return saved;
}

void ModelState::check_consistency()
{
  unsigned int count;
  EmissionIter emit_iter; 
  TransIter trans_iter; 

  /* is the number of emissions consistent */
  for (count = 0, emit_iter = emissions.begin(); emit_iter != emissions.end(); ++emit_iter)
    count += emit_iter->second;
  if (count != emission_count)
    throw ModelConsistencyException("corrupted emission count");

  /* is the number of forward transitions consistent */
  for (count = 0, trans_iter = trans_forward.begin(); trans_iter != trans_forward.end(); ++trans_iter)
    count += trans_iter->second->count;
  if (count != trans_count)
    throw ModelConsistencyException("corrupted transition count");


  /* is there a back transition for each forward transition */
  for (trans_iter = trans_forward.begin(); trans_iter != trans_forward.end(); ++trans_iter) {
      ModelState *target;
      ModelTrans *back_ptr;
 
      target = trans_iter->second->target;
      if (((back_ptr = target->trans_back[this->id]) == 0) || (back_ptr->count != trans_iter->second->count))
	  throw ModelConsistencyException("corrupted forward transition");
  }

  /* is there a forward transition for each back transition */
  for (trans_iter = trans_back.begin(); trans_iter != trans_back.end(); ++trans_iter) {
      ModelState *target;
      ModelTrans *forward_ptr;
 
      target = trans_iter->second->target;
      if (((forward_ptr = target->trans_forward[this->id]) == 0) || (forward_ptr->count != trans_iter->second->count))
	  throw ModelConsistencyException("corrupted back transition");
  }
}

unsigned int ModelState::get_id()
{
    return id;
}
 
/* Add 'count' observations of the transition this -> target: create the edge
 * if it does not exist yet, otherwise bump its counter. trans_count (the
 * total over all forward edges) grows by 'count' either way. The matching
 * back edge at 'target' is NOT touched here. */
void ModelState::add_forward_edge(ModelState *target, unsigned int count)
{
    TransIter existing = trans_forward.find(target->id);

    if (existing != trans_forward.end())
        existing->second->count += count;
    else
        trans_forward[target->id] = new ModelTrans(target, count);

    trans_count += count;
}

/* Remove 'count' observations of the transition this -> target, deleting the
 * edge once its counter reaches zero, and lower trans_count accordingly.
 * Throws ModelConsistencyException if the edge does not exist, holds fewer
 * than 'count' observations, or trans_count would drop below zero. All checks
 * run before anything is modified, so a failed call leaves the state intact
 * (previously the trans_count check fired only after the edge had already
 * been decremented or deleted). The matching back edge at 'target' is NOT
 * touched here. */
void ModelState::subtract_forward_edge(ModelState *target, unsigned int count)
{
    TransIter iter = trans_forward.find(target->id);

    if (iter == trans_forward.end()) 
        throw ModelConsistencyException("trying to remove non-exsiting forward edge");

    ModelTrans *trans = iter->second;

    if (trans->count < count) 
        throw ModelConsistencyException("trying to remove too many counts from forward edge");
    if (trans_count < count) 
        throw ModelConsistencyException("trans_count drops below zero");

    if (trans->count > count) { 
        /* remove the counts from edge */
        trans->count -= count;
    }
    else {
        /* completely remove edge */
        trans_forward.erase(iter);
        delete trans;
    }

    trans_count -= count;
}

/* Record 'count' observations of the transition target -> this in the back
 * table: create the entry if needed, otherwise bump its counter. Back edges
 * do not contribute to trans_count. */
void ModelState::add_back_edge(ModelState *target, unsigned int count)
{
    TransIter existing = trans_back.find(target->id);

    if (existing != trans_back.end())
        existing->second->count += count;
    else
        trans_back[target->id] = new ModelTrans(target, count);
}

/* Remove 'count' observations of the transition target -> this from the
 * back table, deleting the entry when its counter reaches zero. Throws
 * ModelConsistencyException if the entry is missing or too small. */
void ModelState::subtract_back_edge(ModelState *target, unsigned int count)
{
    TransIter slot = trans_back.find(target->id);

    if (slot == trans_back.end()) 
        throw ModelConsistencyException("trying to remove non-exsiting back edge");

    ModelTrans *edge = slot->second;

    if (edge->count < count) 
        throw ModelConsistencyException("trying to remove too many counts from back edge");

    if (edge->count == count) {
        /* nothing left: drop the entry altogether */
        trans_back.erase(slot);
        delete edge;
    }
    else {
        edge->count -= count;
    }
}

/* Redirect every outgoing edge of 'state' so that it leaves 'this' instead
 * (used when 'state' is merged into 'this', see HmmImpl::merge_states):
 *   - an edge state -> X with X != state becomes this -> X: the forward edge
 *     is added here, X gains a back edge from 'this' and loses the one from
 *     'state' (add before subtract, with the same count).
 *   - a self-loop state -> state becomes the self-loop this -> this; only the
 *     forward half is added here, the back half is added by merge_back()
 *     when it walks state->trans_back.
 * Only 'this' and the edge targets are modified; state's own tables are just
 * read (assuming this != state), so iterating them here is safe.
 * If 'affected' is non-null, the id of every edge target is inserted into it.
 * Undone by backtrack_forward().
 */
void ModelState::merge_forward(ModelState *state, set<unsigned int> *affected)
{
  TransIter iter;
  ModelState *target;
  ModelTrans *trans;

  for (iter = state->trans_forward.begin(); iter != state->trans_forward.end(); ++iter) {

      trans = iter->second;
      target = trans->target;

      /* the target's edge tables change, so its probability factors must be re-evaluated */
      if (affected != 0) 
	affected->insert(target->id);

      if (trans->has_target(state)) {
	  /* self-loop: back half is added by merge_back() */
	  this->add_forward_edge(this, trans->count);
      }
      else {
	  this->add_forward_edge(target, trans->count);
	  target->add_back_edge(this, trans->count);
	  target->subtract_back_edge(state, trans->count);
      }
  }
}

/* Redirect every incoming edge of 'state' so that it enters 'this' instead
 * (mirror image of merge_forward()):
 *   - an edge X -> state with X != state becomes X -> this: the back entry is
 *     added here, X gains a forward edge to 'this' and loses the one to
 *     'state' (which also moves X's trans_count back to its old total).
 *   - the back half of a self-loop on 'state' becomes a back entry
 *     this <- this; the forward half was added by merge_forward().
 * Only 'this' and the edge sources are modified; state's own tables are just
 * read (assuming this != state).
 * If 'affected' is non-null, the id of every edge source is inserted into it.
 * Undone by backtrack_back().
 */
void ModelState::merge_back(ModelState *state, set<unsigned int> *affected)
{
  TransIter iter;
  ModelState *target;
  ModelTrans *trans;

  for (iter = state->trans_back.begin(); iter != state->trans_back.end(); ++iter) {

      trans = iter->second;
      target = trans->target;

      /* the source's forward edges change, so its probability factors must be re-evaluated */
      if (affected != 0) 
	affected->insert(target->id);

      if (trans->has_target(state)) {
	  /* self-loop: forward half was added by merge_forward() */
	  this->add_back_edge(this, trans->count);
      }
      else {
	  this->add_back_edge(target, trans->count);
	  target->add_forward_edge(this, trans->count);
	  target->subtract_forward_edge(state, trans->count);
      }
  }
}

/* Undo merge_forward(state): for every outgoing edge still recorded in
 * state's (unchanged) forward table, remove its counts from 'this' and give
 * the target's back edge back to 'state'. Each step is the inverse of the
 * corresponding merge_forward() step; the subtract_* calls throw
 * ModelConsistencyException if the counts they expect are not there.
 * If 'affected' is non-null, the id of every edge target is inserted into it.
 */
void ModelState::backtrack_forward(ModelState *state, set<unsigned int> *affected)
{
  TransIter iter;
  ModelState *target;
  ModelTrans *trans;

  for (iter = state->trans_forward.begin(); iter != state->trans_forward.end(); ++iter) {

      trans = iter->second;
      target = trans->target;

      if (affected != 0) 
	affected->insert(target->id);

      if (trans->has_target(state)) {
	  /* self-loop: back half is restored by backtrack_back() */
	  this->subtract_forward_edge(this, trans->count);
      }
      else {
	  this->subtract_forward_edge(target, trans->count);
	  target->subtract_back_edge(this, trans->count);
	  target->add_back_edge(state, trans->count);
      }
  }
}

/* Undo merge_back(state): for every incoming edge still recorded in state's
 * (unchanged) back table, remove its counts from 'this' and hand the source's
 * forward edge back to 'state'. Each step is the inverse of the corresponding
 * merge_back() step; the subtract_* calls throw ModelConsistencyException if
 * the counts they expect are not there.
 * If 'affected' is non-null, the id of every edge source is inserted into it.
 */
void ModelState::backtrack_back(ModelState *state, set<unsigned int> *affected)
{
  TransIter iter;
  ModelState *target;
  ModelTrans *trans;

  for (iter = state->trans_back.begin(); iter != state->trans_back.end(); ++iter) {

      trans = iter->second;
      target = trans->target;

      if (affected != 0) 
	affected->insert(target->id);

      if (trans->has_target(state)) {
	  /* self-loop: forward half is restored by backtrack_forward() */
	  this->subtract_back_edge(this, trans->count);
      }
      else {
	  this->subtract_back_edge(target, trans->count);
	  target->subtract_forward_edge(this, trans->count);
	  target->add_forward_edge(state, trans->count);
      }
  }
}

/* Count 'count' more emissions of 'token' in this state. A token seen for
 * the first time gets a new entry; note that no lock() is taken for it here
 * (unlike the ModelState(m, token) constructor). */
void ModelState::add_emission(Item *token, unsigned int count)
{
  EmissionIter slot = emissions.find(token);

  if (slot == emissions.end())
    emissions[token] = count;
  else
    slot->second += count;

  emission_count += count;
}

/* fold every emission counter of 'state' into this state's emissions */
void ModelState::merge_emissions(ModelState *state)
{
    EmissionIter source = state->emissions.begin();

    while (source != state->emissions.end()) {
      add_emission(source->first, source->second);
      ++source;
    }
}

/* Take back merge_emissions(state): subtract every emission counter of
 * 'state' from this state, dropping tokens whose counter reaches zero.
 * Throws ModelConsistencyException if a token is missing, its counter is too
 * small, or emission_count would drop below zero.
 *
 * Entries are erased through their iterator rather than by key: erase(key)
 * with found->first would hand erase() a reference to the key stored inside
 * the very node being deleted, and hash tables that keep comparing against
 * the key after freeing the matching node (such as the SGI-derived
 * __gnu_cxx hashtable) would then read freed memory.
 * NOTE(review): assumes 'emissions' may be such a hash table — confirm in markov.h.
 */
void ModelState::backtrack_emissions(ModelState *state)
{
  EmissionIter iter, found;

  for (iter = state->emissions.begin(); iter != state->emissions.end(); ++iter) {

      found = emissions.find(iter->first);

      if (found == emissions.end())
	  throw ModelConsistencyException("backtrack_emissions: subtracting non-existing input token");

      if (found->second < iter->second)
	  throw ModelConsistencyException("backtrack_emissions: trying to subtract more input tokens than possible");
      else if (found->second > iter->second) 
	  found->second -= iter->second;
      else 
	  emissions.erase(found);

      if (emission_count < iter->second) 
	  throw ModelConsistencyException("backtrack_emissions: inconsistent number of emissions - dropping below zero");
      else
	  emission_count -= iter->second;
  }
}

/* Return the unique forward successor whose emissions contain 'token'.
 * Returns 0 when no successor emits it, or when more than one does. */
ModelState* ModelState::find_successor(Item *token)
{
  ModelState *match = 0;

  for (TransIter edge = trans_forward.begin(); edge != trans_forward.end(); ++edge) {
    ModelState *candidate = edge->second->target;

    if (candidate->emissions.find(token) == candidate->emissions.end())
      continue;

    if (match != 0)
      return 0;   /* ambiguous: a second successor emits the token */

    match = candidate;
  }

  return match;
}

/* Print a dot (graph) representation of this state to 'out': one node
 * labelled with the id and the "token/count" emission pairs, followed by one
 * edge per forward transition labelled with its relative frequency
 * count / trans_count. Edge labels are printed with precision 2; the stream's
 * previous precision is restored afterwards so the setting does not leak into
 * later, unrelated output on the caller's stream. */
void ModelState::to_string(ostream &out)
{
  EmissionIter emit;
  TransIter trans_iter;
  ModelTrans *trans;

  out << "   node_" << id << " [ label = \"" << id << "\\n{";

  /* print node attributes - i.e. emission */
  for (emit = emissions.begin(); emit != emissions.end(); ++emit) {
    emit->first->to_string(out);
    out << "/" << emit->second << " ";
  }

  out <<  "}\"];\n";

  /* print nodes forward transitions */
  const std::streamsize old_precision = out.precision(2);
  for (trans_iter = trans_forward.begin(); trans_iter != trans_forward.end(); ++trans_iter) {
    trans = trans_iter->second;
    out << "   node_" << id << " -> node_" << trans->target->id << " [ label = \"" << ((long double) trans->count / (long double) trans_count) << "\"];\n";
  }
  out.precision(old_precision);
}

/* debug dump of the id and the four cached probability values to cout */
void ModelState::print_state()
{
  cout << "State " << id << ":\n"
       << "   p_s = " << p_structure << "\n"
       << "   p_last_s = " << last_structure << "\n"
       << "   p_l = " << p_likelihood << "\n"
       << "   p_last_l = " << last_likelihood << "\n";
}

/******************************************************************************/
/*                        Markov Model Methods                                */
/******************************************************************************/

/* Create an empty model holding only the two sentinel states. The id counter
 * is reset first so that 'start' receives id 0 and 'terminal' id 1, which
 * check_consistency() and _check_item() rely on. */
HmmImpl::HmmImpl() 
{
  unique = 0;
  _inserted = 0;

  probability = 0.0;
  valid_probability = false;
  valid_reduced = false;

  start = new ModelState(this);      /* id 0 */
  terminal = new ModelState(this);   /* id 1 */
}

/* Destroy the model. Every Item referenced by an emission table (start,
 * terminal, then each regular state) or by the alphabet is collected FIRST,
 * one entry per occurrence; then all states are deleted; finally release()
 * is called once per collected entry. Dangling (merged-away but not
 * finalized) states are not touched here. */
HmmImpl::~HmmImpl()
{
  list<Item *> kill_list;
  EmissionIter emit;
  StateIter state;

  ModelState *sentinels[2] = { start, terminal };
  for (int s = 0; s < 2; ++s)
    for (emit = sentinels[s]->emissions.begin(); emit != sentinels[s]->emissions.end(); ++emit)
      kill_list.push_back(emit->first);

  for (state = states.begin(); state != states.end(); ++state)
    for (emit = state->second->emissions.begin(); emit != state->second->emissions.end(); ++emit)
      kill_list.push_back(emit->first);

  __gnu_cxx::hash_map<Item*, unsigned char>::iterator letter;
  for (letter = alphabet.begin(); letter != alphabet.end(); ++letter)
    kill_list.push_back(letter->first);

  /* the states own their edge objects; the Items are handled below */
  delete start;
  delete terminal;
  for (state = states.begin(); state != states.end(); ++state)
    delete state->second;

  for (list<Item *>::iterator victim = kill_list.begin(); victim != kill_list.end(); ++victim)
    (*victim)->release();
}


/* hand out state ids sequentially (the constructor consumes 0 and 1) */
int HmmImpl::get_next_state_id()
{
  const int issued = unique;
  ++unique;
  return issued;
}

long double HmmImpl::eval_likelihood(bool overwrite)
{
  long double p;
  StateIter iter; 

  p = start->get_likelihood_probability(overwrite);
  for (iter = states.begin(); iter != states.end(); ++iter) {
    p *= iter->second->get_likelihood_probability(overwrite);
  }

  return p;
}

long double HmmImpl::eval_structure(bool reduced, bool overwrite)
{
  long double p;
  StateIter iter; 

  p = start->get_structural_probability(reduced, overwrite);
  
  for (iter = states.begin(); iter != states.end(); ++iter) {
    p *= iter->second->get_structural_probability(reduced, overwrite);
  }

  return p;
}

void HmmImpl::check_consistency() 
{
  StateIter iter;  
  
  if ((start == 0) || (start->get_id() != 0))
      throw ModelConsistencyException("start state corrupted");

  if ((terminal == 0) || (terminal->get_id() != 1))
      throw ModelConsistencyException("terminal state corrupted");

  for (iter = states.begin(); iter != states.end(); ++iter)
      iter->second->check_consistency();

}

/* Incrementally update 'probability' after a merge. For each affected state
 * except the terminal state (not part of the evaluated product) and 'merged'
 * (divided out by the caller), recompute and store its full structural and
 * likelihood factors (the previous values become the 'last' values), then
 * rescale 'probability' by new/old for both factors. Throws
 * ModelConsistencyException, after dumping the states, if an affected id is
 * unknown. */
void HmmImpl::update_merge_probability(set<unsigned int> *affected, unsigned int merged)
{
  for (set<unsigned int>::iterator it = affected->begin(); it != affected->end(); ++it) {
    const unsigned int sid = *it;
    ModelState *state;

    if (sid == start->get_id()) {
      state = start;
    }
    else if ((sid == terminal->get_id()) || (sid == merged)) {
      continue;
    }
    else {
      StateIter found = states.find(sid);
      if (found == states.end()) {
        cout << "tried to locate " << sid;
        print_states();
        throw ModelConsistencyException("HmmImpl::update_merge_probability could not find expected state");
      }
      state = found->second;
    }

    const long double new_struct = state->get_structural_probability(false, true);
    const long double new_like = state->get_likelihood_probability(true);

    const long double old_struct = state->get_last_structural_probability();
    const long double old_like = state->get_last_likelihood_probability();

    this->probability *= ((new_struct / old_struct) * (new_like / old_like));
  }
}

/* Reverse update_merge_probability() after a backtrack: for each affected
 * state except the terminal state and 'merged', restore the factors saved
 * during the merge and rescale 'probability' by the ratios that reset
 * returns. Throws ModelConsistencyException if an affected id is unknown. */
void HmmImpl::update_backtrack_probability(set<unsigned int> *affected, unsigned int merged)
{
  for (set<unsigned int>::iterator it = affected->begin(); it != affected->end(); ++it) {
    const unsigned int sid = *it;
    ModelState *state;

    if (sid == start->get_id()) {
      state = start;
    }
    else if ((sid == terminal->get_id()) || (sid == merged)) {
      continue;
    }
    else {
      StateIter found = states.find(sid);
      if (found == states.end())
        throw ModelConsistencyException("HmmImpl::update_backtrack_probability could not found expected state");
      state = found->second;
    }

    const long double struct_ratio = state->reset_structural_probability();
    const long double like_ratio = state->reset_likelihood_probability();

    this->probability *= (struct_ratio * like_ratio);
  }
}

/* Merge state 's2' into state 's1' (both of this model).
 *
 * s2's emissions and edges are folded into s1 (merge_emissions/forward/back;
 * none of these modify s2's own tables). s2 is then removed from 'states' and
 * parked in 'dangling' rather than deleted, so the merge can be undone by
 * backtrack_states() with the returned ModelUpdate, or committed by
 * finalize(). The caller owns the returned ModelUpdate (both of those
 * functions delete it).
 *
 * If valid_reduced is set, 'probability' is updated incrementally: only the
 * states collected in 'affected' are re-evaluated, and s2's factor is divided
 * out. This works because get_structural_probability() uses base N (model
 * size) in reduced form and N'+1 in full form: after erasing s2 the model has
 * N' = N-1 states, so the full-form base equals the reduced base used before
 * the merge, and unaffected states keep their factor. Afterwards the value is
 * presumably the full (non-reduced) probability of the merged model, hence
 * valid_probability is set.
 */
ModelUpdate* HmmImpl::merge_states(ModelState *s1, ModelState *s2)
{
  ModelUpdate *update;
  EmissionIter emit_iter; /* unused */
  set<unsigned int> affected;

  // cout << "merge " << s2->get_id() << " into " << s1->get_id() << "\n";

  /* allocate and fill update structure */
  update = new ModelUpdate(s1, s2);
  
  /* set s1 as affected, s2 is handled later */
  affected.insert(s1->get_id());
  
  /* add emissions of s2 to s1 */
  s1->merge_emissions(s2);

  /* update transitions - merge outgoing transitions */
  s1->merge_forward(s2, &affected);
  
  /* update transitions - merge back pointers */
  s1->merge_back(s2, &affected);

  /* remove s2 from valid state map and put it into dangling; this must happen
     before the probability update so that get_model_size() no longer counts s2 */
  states.erase(s2->get_id());
  dangling[s2->get_id()] = s2;

  /* update the probability of the model, if possible */
  if (valid_reduced) {
    update_merge_probability(&affected, s2->get_id());

    /* now account for the removed state: s2's tables are unchanged and the
       full-form base (N-1)+1 equals the old reduced base N, so this divides
       out exactly its previous factor */
    probability /= (s2->get_structural_probability(false, true) * s2->get_likelihood_probability(true));

    /* now - probability is valid, too */
    valid_probability = true;
  }

  return update; 
} 

/* Undo a merge of s2 into s1 (as recorded in 'update' by merge_states()).
 *
 * s2's emissions and edges are subtracted from s1 and handed back to their
 * peers, and s2 moves from 'dangling' back into 'states'. If both
 * valid_probability and valid_reduced are set, 'probability' is rolled back
 * to its reduced form: affected states restore their saved factors, and s2's
 * reduced factor is multiplied back in (computed after s2 re-enters 'states',
 * so get_model_size() counts it again). valid_probability is then cleared.
 * Deletes 'update'. The subtract steps throw ModelConsistencyException if
 * the model no longer matches the recorded merge.
 */
void HmmImpl::backtrack_states(ModelUpdate *update)
{
  ModelState *s1, *s2;
  set<unsigned int> affected;
  s1 = update->get_merge_dest(); s2 = update->get_merge_source();

  // cout << "split " << s2->get_id() << " from " << s1->get_id() << "\n";

  /* set s1 as affected, account for s2 later */
  affected.insert(s1->get_id());

  /* remove emissions of s2 from s1 */
  s1->backtrack_emissions(s2);

  /* backtrack transitions - backtrack outgoing transitions */
  s1->backtrack_forward(s2, &affected);  

  /* backtrack transitions - backtrack back pointers */
  s1->backtrack_back(s2, &affected);

  /* remove s2 from dangling states and put it back into regular state */
  dangling.erase(s2->get_id());
  states[s2->get_id()] = s2;

  /* update the probability of the model, if possible */
  if ((valid_probability) && (valid_reduced)) {
    update_backtrack_probability(&affected, s2->get_id());

    /* now account for the removed state - just reset values */
    probability *= (s2->get_structural_probability(true, true) * s2->get_likelihood_probability(true));

    /* probability is no longer valid */
    valid_probability = false;
  }

  /* free update structure - not needed any longer */
  delete update;
}

/* Permanently merge 'src' into 'dst' and free 'src'.
 *
 * Items emitted by both states are collected before the merge and release()d
 * once afterwards; presumably this drops the reference that 'src' held,
 * since 'dst' keeps its own (confirm against Item's lock/release contract).
 * Items emitted only by 'src' end up in dst's table via add_emission(), which
 * takes no extra lock(), so their existing reference is simply carried over.
 * valid_reduced is cleared at the end, so the next trial merge in optimize()
 * re-evaluates the reduced probability from scratch.
 */
void HmmImpl::finalize(ModelState *dst, ModelState *src)
{
  /* store all items that are in the emissions of both states, src and
     dst. these items will not be needed after the merge and should be
     deleted 
  */
  list<Item *> kill_list;
  EmissionIter iter, found;
  for (iter = src->emissions.begin(); iter != src->emissions.end(); ++iter) {
    found = dst->emissions.find((*iter).first);
    if (found != dst->emissions.end()) {
      // when the item can be found, add it to the kill_list 
      kill_list.push_back((*iter).first);
    }
  }

  /* do the final merge */
  ModelUpdate *update = merge_states(dst, src);

  /* free dangling state (merge_states parked it in 'dangling') */
  dangling.erase(src->get_id());
  delete src;

  /* free update structure - not needed any longer */
  delete update;

  /* release the kill_list items */
  list<Item *>::iterator kill_iter;
  for (kill_iter = kill_list.begin(); kill_iter != kill_list.end(); ++kill_iter)
    (*kill_iter)->release();

  valid_reduced = false;
}

void HmmImpl::print_states()
{
  StateIter iter;
  ModelState *state;
 
  start->print_state();
 
  for (iter = states.begin(); iter != states.end(); ++iter) {

    state = iter->second;
    state->print_state();
  }
}


/******************************************************************************/
/*                        API FUNCTIONS                                       */
/******************************************************************************/

/* number of regular states (start and terminal are not counted) */
unsigned int HmmImpl::get_model_size()
{
  const unsigned int regular_states = states.size();
  return regular_states;
}

/* number of distinct input items recorded in the alphabet */
unsigned int HmmImpl::get_alphabet_size()
{
  const unsigned int letters = alphabet.size();
  return letters;
}

/* Insert one input sequence into the model.
 *
 * Sequences the model already accepts, and sequences that would push the
 * state count past ABSOLUTE_MAX, are ignored. Otherwise the sequence is
 * threaded from 'start' to 'terminal' with prefix compression: at each step
 * the unique successor that already emits the token is reused, and a fresh
 * state is created when there is none. New tokens are added to the alphabet
 * (and locked). Cached probabilities are invalidated, and the model is
 * optimized once it grows beyond SOFT_MAX states.
 */
void HmmImpl::_insert_item(ListCollection &input_sequence)
{
  if (_check_item(&input_sequence) > 0.0) 
    return;

  if ((input_sequence.size() + get_model_size()) > ABSOLUTE_MAX) 
    return;

  ModelState *current = start;

  for (ListCollection::iterator pos = input_sequence.begin(); pos != input_sequence.end(); ++pos) {
    Item *token = *pos;

    /* first sighting of this token: record it and take a reference */
    if (alphabet.find(token) == alphabet.end()) {
      alphabet[token] = 0;
      token->lock();
    }

    ModelState *next = current->find_successor(token);
    if (next == 0) {
      next = new ModelState(this, token);
      states[next->get_id()] = next;
    }
    else {
      /* the successor already emits the token, so only its counter grows */
      next->add_emission(token, 1);
    }

    current->add_forward_edge(next, 1);
    next->add_back_edge(current, 1);
    current = next;

    ++_inserted;
  }

  current->add_forward_edge(terminal, 1);
  terminal->add_back_edge(current, 1);

  this->valid_probability = false;
  this->valid_reduced = false;

  if (get_model_size() > SOFT_MAX)
    this->optimize();
}

/* model probability, overwriting the per-state cached factors if recomputed */
long double HmmImpl::get_probability()
{
  const bool overwrite_cache = true;
  return get_probability(overwrite_cache);
}

/* Return the full (non-reduced) model probability: structural prior times
 * likelihood. The cached value is reused while valid_probability holds;
 * otherwise it is recomputed, which also invalidates the reduced form. */
long double HmmImpl::get_probability(bool overwrite)
{
  if (valid_probability)
    return probability;

  probability = eval_structure(false, overwrite) * eval_likelihood(overwrite);
  valid_reduced = false;
  valid_probability = true;

  return probability;
}

/* Make 'probability' hold the reduced form (structural factors use the model
 * size as base instead of size + 1), which merge_states() can update
 * incrementally. No-op if the reduced form is already current. */
void HmmImpl::set_reduced_probability()
{
  if (valid_reduced)
    return;

  probability = eval_structure(true, true) * eval_likelihood(true);
  valid_probability = false; 
  valid_reduced = true;
}

#ifdef DUMP_FILES
static int step_count = 0;
#endif

/* Greedy state merging to optimize the a-posteriori probability of the model
 * with respect to its input.
 *
 * Each round tries every pair of regular states: the merge is evaluated
 * incrementally (set_reduced_probability() + merge_states()), scored with
 * get_probability(), and undone with backtrack_states(). The best pair is
 * committed with finalize() while the model is still above SOFT_MAX states
 * or the best merge beats the current model probability.
 *
 * opt1/opt2 are reset every round. If no trial merge scores above 0.0 (fewer
 * than two states, or every probability evaluates to 0.0) the loop stops.
 * Previously finalize() could be reached with uninitialized pointers in that
 * case once the model exceeded SOFT_MAX.
 */
void HmmImpl::optimize()
{
  StateIter siter;
  long double model_value, current_best, current;
  ModelState *opt1, *opt2;
  list<ModelState *> state_list;
  list<ModelState *>::iterator siter1, siter2;

  model_value = get_probability();

  while (1) {

#ifdef CHECK_CONSISTENCY
    check_consistency();
#endif
    current_best = 0.0;
    opt1 = opt2 = 0;

    /* snapshot the states: merge_states()/backtrack_states() erase and
       re-insert entries of 'states' while the pairs are being tried */
    state_list.clear();

    for (siter = states.begin(); siter != states.end(); ++siter) 
      state_list.push_back(siter->second);
    
    for (siter1 = state_list.begin(); siter1 != state_list.end(); ++siter1) {

      siter2 = siter1; 
      ++siter2;
      for (; siter2 != state_list.end(); ++siter2) {

#ifdef CHECK_CONSISTENCY
	/* full-model check per trial merge: guarded like the one above */
	check_consistency();
#endif

	set_reduced_probability();
	ModelUpdate *update = merge_states(*siter1, *siter2);

	current = get_probability();
	if (current > current_best) {
	  current_best = current;
	  opt1 = *siter1; opt2 = *siter2;
	}

	backtrack_states(update);
      }
    }

    /* no usable candidate pair: nothing can be committed */
    if (opt1 == 0)
      break;

    if ((get_model_size() > SOFT_MAX) || (current_best > model_value)) {
      finalize(opt1, opt2);
      model_value = current_best;
    }
    else
      break;
  }
  
#ifdef DUMP_FILES
  char c_string[64];
  sprintf(c_string, "step%d.dot", ++step_count);
  string s = string(c_string);
  ofstream of(s.c_str());
  to_string(of);
#endif
}

/* Decide whether the model accepts 'sequence'.
 *
 * Forward pass: 'frontier' maps state id -> probability mass of being in
 * that state after the symbols consumed so far (starting with all mass on
 * 'start'). Each symbol pushes mass along the forward edges whose target
 * emits it, weighted by count / trans_count. Returns 1.0 if any mass can
 * reach 'terminal' after the last symbol, 0.0 otherwise; the probability
 * itself only serves as a reachability test. If a frontier id names no known
 * state, the model is written to "faulty.dot" and ModelConsistencyException
 * is thrown.
 */
long double HmmImpl::_check_item(ListCollection *sequence)
{
  __gnu_cxx::hash_map<unsigned int, long double> table_one, table_two;
  __gnu_cxx::hash_map<unsigned int, long double> *frontier = &table_one;
  __gnu_cxx::hash_map<unsigned int, long double> *successors = &table_two;
  StepIter step;

  (*frontier)[start->get_id()] = 1.0;

  for (ListCollectionIterator pos = sequence->begin(); pos != sequence->end(); ++pos) {

    /* nothing reachable any more: reject early */
    if (frontier->empty())
      return 0.0;

    Item *symbol = *pos;

    for (step = frontier->begin(); step != frontier->end(); ++step) {
      ModelState *from;

      /* ids 0 and 1 are the sentinels, which are not kept in 'states' */
      if (step->first <= 1) {
        from = (step->first == 0) ? start : terminal;
      }
      else {
        StateIter known = states.find(step->first);
        if (known == states.end()) {
          cout << "state not found is " << step->first;
          ofstream faulty("faulty.dot");
          to_string(faulty);
          throw ModelConsistencyException("markov model contains links to non existing states");
        }
        from = known->second;
      }

      for (TransIter edge = from->trans_forward.begin(); edge != from->trans_forward.end(); ++edge) {
        ModelState *to = edge->second->target;

        if (to->emissions.find(symbol) == to->emissions.end())
          continue;

        /* operator[] starts a missing entry at 0.0, so += covers both the
           first and any later contribution to the same state */
        (*successors)[to->get_id()] += (step->second) * ((long double) edge->second->count / (long double) from->trans_count);
      }
    }

    /* the consumed frontier is emptied and becomes the next scratch table */
    frontier->clear();
    __gnu_cxx::hash_map<unsigned int, long double> *spare = frontier;
    frontier = successors;
    successors = spare;
  }

  long double to_terminal = 0.0;

  for (step = frontier->begin(); step != frontier->end(); ++step) {
    ModelState *from;

    if (step->first <= 1) {
      from = (step->first == 0) ? start : terminal;
    }
    else {
      StateIter known = states.find(step->first);
      if (known == states.end()) {
        cout << "state not found is " << step->first;
        ofstream faulty("faulty.dot");
        to_string(faulty);
        throw ModelConsistencyException("markov model contains links to non existing states");
      }
      from = known->second;
    }

    TransIter last_edge = from->trans_forward.find(terminal->get_id());
    if (last_edge != from->trans_forward.end())
      to_terminal += (step->second) * (((long double) last_edge->second->count / (long double) from->trans_count));
  }

  return (to_terminal > 0.0) ? 1.0 : 0.0;
}


/* Write the whole model to 'out' as a dot digraph named "grammar": start,
 * terminal, then every regular state with its outgoing edges. */
void HmmImpl::to_string(ostream &out)
{
  out << "digraph grammar {\n";

  ModelState *sentinels[2] = { start, terminal };
  for (int s = 0; s < 2; ++s)
    sentinels[s]->to_string(out);

  for (StateIter it = states.begin(); it != states.end(); ++it)
    it->second->to_string(out);

  out << "}";
}


/* Public entry point for insertion: 'item' must be a ListCollection (a
 * sequence of items), otherwise ModelInputException is thrown. */
void HmmImpl::insert_item(Item *item) throw (ModelInputException)
{
  ListCollection *sequence = dynamic_cast<ListCollection *>(item);

  if (!sequence)
    throw ModelInputException("ListCollection required for HmmImpl::insert_item");

  _insert_item(*sequence);
}

/* switch to different mode
 *
 * Currently a no-op: the ModelMode argument is ignored. The commented-out
 * body is a disabled debugging path that dumped the model in dot format to a
 * mkstemp() temp file and then ran optimize().
 * NOTE(review): if it is ever re-enabled, mkstemp() reports failure with -1
 * (not 0), and the descriptor it returns is never closed.
 */
void HmmImpl::switch_mode(ModelMode)
{
  /* debug output
  cout << "dumping data";
  { 
  char filename[] = "model.XXXXXX";
  if(! mkstemp(filename)) {
  cerr << "Error making temp file\n";
  exit(1);
  }
  ofstream os(filename);
  to_string(os);
  os.close();
  }
  optimize();  
  */
}

/* Check an item for accordance with the model: 'item' must be a
 * ListCollection, otherwise ModelInputException is thrown. Returns the
 * result of _check_item() as a double (1.0 if the model accepts the
 * sequence, 0.0 otherwise). The exception message now names check_item;
 * it was copy-pasted from insert_item. */
double HmmImpl::check_item(Item *item) throw (ModelInputException)
{
  ListCollection *_item = dynamic_cast<ListCollection *>(item);
  
  if (_item == 0)
    throw ModelInputException("ListCollection required for HmmImpl::check_item");
  
  return (double) _check_item(_item);
}

/* Return the confidence in the model.
 *
 * 1.0 once at least ABSOLUTE_MAX input elements have been inserted, 0 while
 * nothing has been inserted; otherwise
 *
 *   1 - (states / inserted) * ((ABSOLUTE_MAX - inserted) / ABSOLUTE_MAX)
 *
 * so confidence grows with the amount of evidence and with the degree of
 * generalization (fewer states per inserted element means more merging
 * took place).
 */
double HmmImpl::get_confidence()
{
  if (_inserted >= ABSOLUTE_MAX)
    return 1.0;

  if (_inserted == 0)
    return 0;

  const double numerator = (double) (get_model_size() * (ABSOLUTE_MAX - _inserted));
  const double denominator = (double) (ABSOLUTE_MAX * _inserted);

  return 1.0 - (numerator / denominator);
}
  
bool HmmImpl::module_test(bool output)
{
  ListCollection input_sequence;
  ListCollectionIterator list;
  long double result;

  if (output) cerr << "Inserting 'abab'\n";
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  insert_item(&input_sequence);
  for (list = input_sequence.begin(); list != input_sequence.end(); ++list)
    (*list)->release();
  input_sequence.clear();


  if (output) cerr << "Inserting 'ab'\n";
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  insert_item(&input_sequence);
  for (list = input_sequence.begin(); list != input_sequence.end(); ++list)
    (*list)->release();
  input_sequence.clear();

  if (output) cerr << "Building the automaton\n";
  optimize();

  double confidence = get_confidence();
  if (confidence <= 1.0) { 
    if (output) cerr << "Confidence = " << confidence << " ... ok\n";
  }
  else {
    if (output) cerr << "Confidence = " << confidence << " ... failed\n";
    return false;
  }

  if (output) cerr << "Test1: 'ab' -  ";
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  result = check_item(&input_sequence);
  if (output) cerr << result;
  if (result == 1.0) {
    if (output) cerr << " ... ok\n";
  }
  else {
    if (output) cerr << " ... failed\n";
    return false;
  }
  for (list = input_sequence.begin(); list != input_sequence.end(); ++list)
    (*list)->release();
  input_sequence.clear();
  
  if (output) cerr << "Test2: 'ababab' - ";
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('b'));
  result = check_item(&input_sequence);
  if (output) cerr << result;
  if (result == 1.0) {
    if (output) cerr << " ... ok\n";
  }
  else {
    if (output) cerr << " ... failed\n";
    return false;
  }
  for (list = input_sequence.begin(); list != input_sequence.end(); ++list)
    (*list)->release();
  input_sequence.clear();

  if (output) cerr << "Test3: 'aca' - ";
  input_sequence.push_back(new CharItem('a'));
  input_sequence.push_back(new CharItem('c'));
  input_sequence.push_back(new CharItem('a'));
  result = check_item(&input_sequence);
  if (output) cerr << result;
  if (result == 0.0) {
    if (output) cerr << " ... ok\n";
  }
  else {
    if (output) cerr << " ... failed\n";
    return false;
  }
  for (list = input_sequence.begin(); list != input_sequence.end(); ++list)
    (*list)->release();
  input_sequence.clear();

  return true;
}

bool HiddenMarkovModel::test(bool output)
{
  if (output) cerr << "Regression Test for Class libAnomaly::HiddenMarkovModel\n";
  if (output) cerr << "Allocated Objects -- " << Item::get_allocated()  << "\n";

  HmmImpl *model = new HmmImpl();
  bool rval = model->module_test(output);
  delete model;

  /* all tests have been successful */
  if (output) cerr << "Allocated Objects (should be equal to number above) -- " << Item::get_allocated()  << "\n";
  if (output) cerr << "\n";

  return rval;
}

/* destructor: nothing is released at the HiddenMarkovModel level */
HiddenMarkovModel::~HiddenMarkovModel()
{
  /* intentionally empty */
}


/*
 * Factory: returns a newly allocated HmmImpl as a Model.
 * The object is created with new; the caller owns it and must delete it.
 * (The stray ';' that followed this definition, an empty declaration that
 * C++98/03 compilers reject under -pedantic, has been removed.)
 */
Model *HiddenMarkovModel::instance()
{
  return new HmmImpl();
}
