Class OnnxPolicy

Inheritance Relationships

Base Type

public legged::Policy

Class Documentation

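class OnnxPolicy : public legged::Policy

A legged::Policy implementation that runs its model through ONNX Runtime (it holds an Ort::Env and an Ort::Session). It is constructed from a model file path.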

Public Types

using SharedPtr = std::shared_ptr<OnnxPolicy>
using tensor_element_t = float
using tensor2d_t = Eigen::Matrix<tensor_element_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>

Public Functions

explicit OnnxPolicy(const std::string &modelPath)
inline virtual size_t getObservationSize() const override
inline virtual size_t getActionSize() const override
virtual void init() override
inline virtual void reset() override
inline virtual vector_t getLastAction() override
virtual vector_t forward(const vector_t &observations) override

Protected Functions

inline std::string getMetadataStr(const std::string &key) const
inline vector_t getMetadataVector(const std::string &key, bool verbose = true) const
virtual void parseMetadata()
virtual void parseInputOutput()
virtual void checkInputOutput()
void run()

Protected Attributes

std::shared_ptr<Ort::Env> onnxEnvPrt_
std::unique_ptr<Ort::Session> sessionPtr_
std::map<std::string, std::string> name2Metadata_
std::vector<Ort::AllocatedStringPtr> inputNamesRaw_
std::vector<Ort::AllocatedStringPtr> outputNamesRaw_
std::vector<const char*> inputNames_
std::vector<const char*> outputNames_
std::vector<std::vector<int64_t>> inputShapes_
std::vector<std::vector<int64_t>> outputShapes_
std::map<std::string, size_t> name2Index_
std::vector<tensor2d_t> inputTensors_
std::vector<tensor2d_t> outputTensors_