Program Listing for File spikefeatures.cpp

Return to documentation for file (processors/spikefeatures/spikefeatures.cpp)

// ---------------------------------------------------------------------
// This file is part of falcon-core.
//
// Copyright (C) 2021-now Neuro-Electronics Research Flanders
//
// Falcon-server is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Falcon-server is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with falcon-core. If not, see <http://www.gnu.org/licenses/>.
// ---------------------------------------------------------------------


#include "spikefeatures.hpp"

// Constructs the processor and declares its user-configurable options.
//
// Options are registered in this order: the feature selection (the only one
// passed a trailing `true` argument), the channel/depth relation, the initial
// detection threshold, the initial peak lifetime and the signal-inversion
// switch.
SpikeFeatures::SpikeFeatures()
    : IProcessor()
{
    // Names of the spike features to compute.
    add_option("features",
               features_,
               "Selection of features to compute.",
               true);

    // Lookup used by the "depth" feature.
    add_option("channeldepths",
               channel_pos_,
               "Relation between channel number and depth.");

    // Initial values for the detector settings; the option names come from
    // the THRESHOLD and PEAK_LIFETIME constants.
    add_option(THRESHOLD,
               initial_threshold_,
               "Spike detection threshold in data units.");
    add_option(PEAK_LIFETIME,
               initial_peak_lifetime_,
               "Peak life time in samples");

    // Switch for detecting negative-going spikes.
    add_option("invert_signal",
               invert_signal_,
               "invert a signal to detect negative spikes");
}

// Validates the requested feature selection and records which features are
// enabled.
//
// The table of supported features is rebuilt with every entry disabled; each
// label listed in the "features" option is then enabled and appended to
// features_labels_ in the order given.
//
// context: global context; not used by this method.
//
// Throws ProcessingConfigureError when a requested label is not one of the
// supported features, or when "depth" is requested while the "channeldepths"
// option is empty.
void SpikeFeatures::Configure(const GlobalContext &context){
    features_labels_ = {};
    default_features_ = YAML::Load( "{ "
                                        "time: false,"
                                        "amplitude: false,"
                                        "slope: false,"
                                        "channel index: false,"
                                        "depth: false"
                                    "}");

    // Bind by const reference: the labels are only read, so there is no need
    // to copy each string.
    for (const auto &label : features_()) {

        // A key that is absent from the table yields an undefined node,
        // i.e. an unsupported feature name.
        if (!default_features_[label]) {
            throw ProcessingConfigureError("The " + label + " is not implemented.", name());
        }

        default_features_[label] = true;
        features_labels_.push_back(label);
    }

    // The depth feature needs a channel -> depth lookup; the message names
    // the option as it is actually registered ("channeldepths").
    if (default_features_["depth"].as<bool>() and channel_pos_().size() == 0) {
        throw ProcessingConfigureError("If the depth feature is selected, the channeldepths option "
                                       "needs to give the mapping between channel label and depth.", name());
    }
}

// Creates the input port, the output port and the two writable shared states.
//
// Creation order is: input port, output port, threshold state, peak-lifetime
// state. Both ports accept between 1 and MAX_NCHANNELS slots.
void SpikeFeatures::CreatePorts(){
    // Input: time series of doubles with 1..MAX_NCHANNELS channels.
    const auto input_capabilities =
            TimeSeriesType<double>::Capabilities(ChannelRange(1, MAX_NCHANNELS));
    data_in_port_ = create_input_port<TimeSeriesType<double>>(
                input_capabilities,
                PortInPolicy(SlotRange(1, MAX_NCHANNELS)));

    // Output: one column per feature label selected in Configure().
    const auto output_parameters =
            ColumnsType<double>::Parameters(features_labels_, 0, true);
    decoding_out_port_ = create_output_port<ColumnsType<double>>(
                output_parameters,
                PortOutPolicy(SlotRange(1, MAX_NCHANNELS)));

    // Detector settings exposed as static states with WRITE permission,
    // seeded from the corresponding options.
    threshold_ = create_static_state(
                THRESHOLD, initial_threshold_(), true, Permission::WRITE);

    peak_lifetime_ = create_static_state(
                PEAK_LIFETIME, initial_peak_lifetime_(), true, Permission::WRITE);
}

// Finalises the output stream description for every slot.
//
// Requires the input and output ports to have the same number of slots. For
// each slot the stream parameters of the input are applied to the output and
// the output column parameters are set from the selected feature labels.
// When the "depth" feature is selected, every channel label of the input
// slot must have an entry in the "channeldepths" option.
//
// Throws ProcessingStreamInfoError when the slot counts differ, and
// ProcessingConfigureError when a channel has no depth entry.
void SpikeFeatures::CompleteStreamInfo() {
    // check if we have the same number of input and output slots
    if (data_in_port_->number_of_slots() != decoding_out_port_->number_of_slots()) {
        auto err_msg = "Number of output slots (" +
                std::to_string(decoding_out_port_->number_of_slots()) +
                ") on port '" + decoding_out_port_->name() +
                "' does not match number of input slots (" +
                std::to_string(data_in_port_->number_of_slots()) +
                ") on port '" + data_in_port_->name() + "'.";
        throw ProcessingStreamInfoError(err_msg, name());
    }

    // The feature selection does not change while iterating over the slots,
    // so look the flag up once.
    const bool depth_selected = default_features_["depth"].as<bool>();

    for (slot_ = 0; slot_ < data_in_port_->number_of_slots(); ++slot_) {
        decoding_out_port_->streaminfo(slot_).set_stream_parameters(data_in_port_->streaminfo(slot_));
        decoding_out_port_->streaminfo(slot_).set_parameters(ColumnsType<double>::Parameters(features_labels_, 0, true));

        if (!depth_selected) {
            continue;
        }

        for (const auto &chan : data_in_port_->prototype(slot_).labels()) {
            try {
                // Only the lookup matters here; at() throws for a missing key.
                (void) channel_pos_().at(chan);
            } catch (const std::out_of_range &) {
                // Name the option as it is registered ("channeldepths") so
                // the user knows which setting to complete.
                throw ProcessingConfigureError("The depth feature is selected but the channeldepths option "
                                               "does not contain the corresponding depth to the channel " + chan, name());
            }
        }
    }
}

// Rebuilds the list of spike detectors, one per input slot.
//
// Each detector is constructed from the column count of the slot prototype,
// the initial threshold and the initial peak lifetime. The detection sign is
// DOWN when the invert_signal option is set and UP otherwise.
//
// context: global context; not used by this method.
void SpikeFeatures::Prepare(GlobalContext &context) {
    // Drop any detectors left from a previous run.
    spike_detectors_.clear();

    for (slot_ = 0; slot_ < data_in_port_->number_of_slots(); ++slot_) {
        const auto detection_sign = invert_signal_()
                ? dsp::algorithms::SpikeDetectionSign::DOWN
                : dsp::algorithms::SpikeDetectionSign::UP;

        spike_detectors_.push_back(
                    std::make_unique<dsp::algorithms::SpikeDetector>(
                        data_in_port_->prototype(slot_).ncolumns(),
                        initial_threshold_(),
                        initial_peak_lifetime_(),
                        detection_sign));
    }
}


// Main processing loop: runs spike detection on every input slot and writes
// the selected features of each detected spike to the matching output slot.
//
// For each slot, per iteration of the outer loop:
//   1. retrieve the input data (a failed retrieval flags the loop to stop),
//   2. claim an output bucket and refresh the detector threshold and peak
//      lifetime from the shared states,
//   3. feed every sample to the slot's detector; for each detected spike,
//      pick the channel with the largest amplitude and store the selected
//      features in the next output row,
//   4. clone the input timestamps onto the output, publish it and release
//      the input.
//
// The loop ends when the context reports termination or a retrieval failed.
//
// context: processing context, polled for termination.
void SpikeFeatures::Process(ProcessingContext &context) {
    // The feature selection is fixed by Configure(), so the YAML lookups are
    // done once here rather than once per detected spike inside the
    // parallel region.
    const bool with_time = default_features_["time"].as<bool>();
    const bool with_channel_index = default_features_["channel index"].as<bool>();
    const bool with_amplitude = default_features_["amplitude"].as<bool>();
    const bool with_slope = default_features_["slope"].as<bool>();
    const bool with_depth = default_features_["depth"].as<bool>();

#pragma omp parallel
    {
        // Everything declared in this block is private to each thread.
        TimeSeriesType<double>::Data *data_in = nullptr;
        ColumnsType<double>::Data *data_out = nullptr;
        std::string channel_label;
        std::vector<double> amp;
        std::string feature_str;
        // NOTE(review): this flag is per-thread, so threads may leave the
        // while loop at different times while others still reach the
        // `omp for` below — confirm this is acceptable for the OpenMP
        // runtime in use.
        bool should_break = false;

        // Appends one "name: value" entry to the debug summary, separating
        // entries with a comma (the summary starts as "(").
        auto append_feature = [&feature_str](const std::string &entry) {
            if (feature_str.size() > 1) {
                feature_str += ", ";
            }
            feature_str += entry;
        };

        while (!context.terminated() and !should_break) {
#pragma omp for nowait
            for (slot_ = 0; slot_ < data_in_port_->number_of_slots(); ++slot_) {

                if (!data_in_port_->slot(slot_)->RetrieveData(data_in)) {
                    should_break = true;
                    continue;
                }

                // claim output data buckets
                data_out = decoding_out_port_->slot(slot_)->ClaimData(true);

                // Pick up any change made to the shared states since the
                // previous bucket.
                spike_detectors_[slot_]->set_threshold(threshold_->get());
                spike_detectors_[slot_]->set_peak_life_time(peak_lifetime_->get());

                // Row of the output bucket that receives the next spike.
                unsigned int spike_number = 0;

                for (size_t sample = 0; sample < data_in->nsamples(); sample++) {
                    if (!spike_detectors_[slot_]->is_spike<ColumnsType<double>::Data::sample_iterator>(
                                data_in->sample_timestamp(sample),
                                data_in->begin_sample(sample))) {
                        continue;
                    }

                    // The spike is attributed to the channel with the
                    // largest detected amplitude.
                    amp = spike_detectors_[slot_]->amplitudes_detected_spike();
                    const size_t max_index = static_cast<size_t>(
                                std::max_element(amp.begin(), amp.end()) - amp.begin());
                    channel_label = data_in->labels()[max_index];
                    feature_str = "(";

                    if (with_time) {
                        const auto spike_time = sample + data_in->hardware_timestamp();
                        data_out->set_data_sample(spike_number, "time", spike_time);
                        append_feature("time: " + std::to_string(spike_time));
                    }

                    if (with_channel_index) {
                        // NOTE(review): std::stod throws if the label is not
                        // numeric — confirm labels are always numbers here.
                        data_out->set_data_sample(spike_number, "channel index", std::stod(channel_label));
                        append_feature("channel index: " + channel_label);
                    }

                    if (with_amplitude) {
                        data_out->set_data_sample(spike_number, "amplitude", amp[max_index]);
                        append_feature("amplitude: " + std::to_string(data_out->data_sample(spike_number, "amplitude")));
                    }

                    if (with_slope) {
                        data_out->set_data_sample(spike_number, "slope", spike_detectors_[slot_]->slopes_detected_spike()[max_index]);
                        append_feature("slope: " + std::to_string(data_out->data_sample(spike_number, "slope")));
                    }

                    if (with_depth) {
                        // channel_label already holds the label of the
                        // max-amplitude channel.
                        data_out->set_data_sample(spike_number, "depth", channel_pos_().at(channel_label));
                        append_feature("depth: " + std::to_string(data_out->data_sample(spike_number, "depth")));
                    }

                    LOG(DEBUG) << name() << " Spike detected : " + feature_str + ") ts= " << sample << " channel =" << channel_label;
                    spike_number++;
                }

                data_out->CloneTimestamps(*data_in);
                decoding_out_port_->slot(slot_)->PublishData();
                data_in_port_->slot(slot_)->ReleaseData();
            }
        }
    }
}


// Invokes the REGISTERPROCESSOR macro for this class.
// NOTE(review): presumably this makes SpikeFeatures available to the
// processor factory — the macro is defined outside this file; confirm there.
REGISTERPROCESSOR(SpikeFeatures)