Program Listing for File onlineencoder.cpp

Return to documentation for file (processors/onlineencoder/onlineencoder.cpp)

// ---------------------------------------------------------------------
// This file is part of falcon-core.
//
// Copyright (C) 2021-now Neuro-Electronics Research Flanders
//
// Falcon-server is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Falcon-server is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with falcon-core. If not, see <http://www.gnu.org/licenses/>.
// ---------------------------------------------------------------------

#include "onlineencoder.hpp"

// Registers the user-configurable options of the online encoder.
// Option registration order is preserved deliberately.
OnlineEncoder::OnlineEncoder() : IProcessor() {
    // Directory containing the offline-computed model files (decoder.hdf5 and
    // one <stream>.hdf5 per input stream). The trailing `true` presumably marks
    // the option as required -- confirm against IProcessor::add_option.
    add_option("model path",
               model_path_,
               "path of the hdf5 decoding model computed offline.",
               true);

    // Process() re-broadcasts the model once this many encoded packets have
    // been accumulated.
    add_option("update model",
               update_frequency_,
               "frequency to update the model.");

    // When enabled, Postprocess() writes the likelihoods back to disk.
    add_option("save model",
               save_model_,
               "save model after processing");

    // Gates whether incoming events are added to the likelihoods in Process().
    add_option("model training",
               training_,
               "only share the model or train it as well.");
}

// Resolves the configured model directory and checks that it exists and
// contains the decoder file.
//
// Side effects: sets path_ to the resolved model directory.
// Throws ProcessingConfigureError if the directory or
// <path_>/decoder.hdf5 is missing.
void OnlineEncoder::Configure(const GlobalContext &context){
    LOG(INFO) << "Model path option: " << model_path_();
    path_ = context.resolve_path(model_path_());

    if (not path_exists(path_)){
        throw ProcessingConfigureError(". The path for the model does not exist: " + path_, name());
    }
    LOG(DEBUG) << "Using decoding model in path: " << path_;

    // Build the decoder file path once; it is used for both the check and the message.
    const auto decoder_path = path_ + "/decoder.hdf5";
    if (not path_exists(decoder_path)){
        throw ProcessingConfigureError(". The path for the decoder does not exist: " + decoder_path, name());
    }
}

void OnlineEncoder::CreatePorts(){

    data_in_port_ = create_input_port<ColumnsType<double>>(
        ColumnsType<double>::Capabilities(ChannelRange(1, MAX_NCHANNELS),
                                          SampleRange(0, std::numeric_limits<uint32_t>::max()), true),
        PortInPolicy(SlotRange(1, MAX_NCHANNELS)));

    to_encode_ = create_follower_state("to encode", true, Permission::READ);

    shared_likelihoods_ = create_broadcaster_state("likelihoods",
                                                   (std::map<std::string, std::shared_ptr<PoissonLikelihood>>*)nullptr,
                                                   Permission::NONE);

    shared_decoder_ = create_broadcaster_state("decoder", (std::shared_ptr<Decoder>*)nullptr, Permission::NONE);

}

void OnlineEncoder::CompleteStreamInfo() {

    nslots_ = data_in_port_->number_of_slots();
    std::string stream_name, likelihood_path;
    // load multilikelihood in
    for (slot_ = 0; slot_ < nslots_; ++slot_) {
        stream_name = data_in_port_->streaminfo(slot_).stream_name();
        likelihood_path = path_ + "/" + stream_name + ".hdf5";
        if (not path_exists(likelihood_path)){
            throw ProcessingConfigureError("The path for the model does not exists: " + likelihood_path, name());
        }
        likelihoods_[stream_name] = PoissonLikelihood::load_from_hdf5( path_ + "/" + stream_name + ".hdf5");


        std::vector<std::string> feats(likelihoods_[stream_name]->event_distribution().space().specification().names());
        if(feats.size() != data_in_port_->prototype(slot_).ncolumns()){ // test if encoding feature in input
            if(likelihoods_[stream_name]->ndim_events() != data_in_port_->prototype(slot_).ncolumns()){ // test if decoding feature only in input
                throw ProcessingConfigureError("There is not the same number of features between"
                                                 " what was used to train model and the one used here."
                                                 " \nModel features are : "
                                               + join(feats,", ")
                                               + "\nFalcon features are: "
                                               + join(data_in_port_->prototype(slot_).labels(), ", "));
            }

            feats.resize(likelihoods_[stream_name]->ndim_events());
        }

        for(uint i=0; i< feats.size(); i++){
            if(feats[i].compare(data_in_port_->prototype(slot_).labels()[i])!=0){
                throw ProcessingConfigureError("The features used to train the model seems "
                                               "different than the one used here. \nModel features are : "
                                               + join(feats,", ")
                                               + "\nFalcon features are: "
                                               + join(data_in_port_->prototype(slot_).labels(), ", "));
            }
        }
        shared_likelihoods_->set(&likelihoods_);
    }
    decoder_ = Decoder::load_from_hdf5(path_ + "/decoder.hdf5");
    shared_decoder_->set(&decoder_);
}


// Main loop. On each iteration it pulls one data packet from every input slot.
// When encoding is enabled (the "to encode" state AND the "model training"
// option), non-empty packets are added as events to that stream's likelihood.
// After `update model` encoded iterations, the likelihoods and decoder are
// re-broadcast. The loop ends on termination, or as soon as any slot fails to
// deliver data.
void OnlineEncoder::Process(ProcessingContext &context){
    ColumnsType<double>::Data* packet = nullptr;
    int packets_since_share = 0;

    // Visits every slot once, using the member slot_ as the cursor.
    // Returns false as soon as a slot yields no data.
    auto consume_all_slots = [&](bool encode) -> bool {
        for (slot_ = 0; slot_ < nslots_; ++slot_) {
            if (!data_in_port_->slot(slot_)->RetrieveData(packet)) {
                return false;
            }
            if (encode && packet->nsamples() > 0) {
                likelihoods_[data_in_port_->streaminfo(slot_).stream_name()]->add_events(packet->data());
            }
            data_in_port_->slot(slot_)->ReleaseData();
        }
        return true;
    };

    while (!context.terminated()) {
        const bool encode = to_encode_->get() && training_();

        if (!consume_all_slots(encode)) {
            break;
        }

        if (encode) {
            ++packets_since_share;
        }

        if (packets_since_share >= update_frequency_()) {
            shared_likelihoods_->set(&likelihoods_);
            shared_decoder_->set(&decoder_);
            packets_since_share = 0;
        }
    }
}

// When the "save model" option is set, writes each stream's likelihood to
// <path_>/falcon/<stream_name>.hdf5. Nothing is written otherwise.
// NOTE(review): the "falcon" subdirectory is not created here. This presumably
// relies on it already existing or on save_to_hdf5 creating it -- confirm.
void OnlineEncoder::Postprocess(ProcessingContext &context) {
    if (!save_model_()) {
        return;
    }

    for (slot_ = 0; slot_ < nslots_; ++slot_) {
        const std::string stream = data_in_port_->streaminfo(slot_).stream_name();
        const std::string target = path_ + "/falcon/" + stream + ".hdf5";
        likelihoods_[stream]->save_to_hdf5(target);
    }
}

REGISTERPROCESSOR(OnlineEncoder)