Class OnnxPolicy
Defined in file OnnxPolicy.h
Inheritance Relationships
Base Type
public legged::Policy
(see Class Policy)
Class Documentation
-
class OnnxPolicy : public legged::Policy
Public Types
-
using tensor_element_t = float
-
using tensor2d_t = Eigen::Matrix<tensor_element_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
Public Functions
-
explicit OnnxPolicy(const std::string &modelPath)
-
inline virtual size_t getObservationSize() const override
-
inline virtual size_t getActionSize() const override
-
virtual void init() override
-
inline virtual void reset() override
-
inline virtual vector_t getLastAction() override
-
virtual vector_t forward(const vector_t &observations) override
Protected Functions
-
inline std::string getMetadataStr(const std::string &key) const
-
inline vector_t getMetadataVector(const std::string &key, bool verbose = true) const
-
virtual void parseMetadata()
-
virtual void parseInputOutput()
-
virtual void checkInputOutput()
-
void run()
Protected Attributes
-
std::shared_ptr<Ort::Env> onnxEnvPrt_
-
std::unique_ptr<Ort::Session> sessionPtr_
-
std::map<std::string, std::string> name2Metadata_
-
std::vector<Ort::AllocatedStringPtr> inputNamesRaw_
-
std::vector<Ort::AllocatedStringPtr> outputNamesRaw_
-
std::vector<const char*> inputNames_
-
std::vector<const char*> outputNames_
-
std::vector<std::vector<int64_t>> inputShapes_
-
std::vector<std::vector<int64_t>> outputShapes_
-
std::map<std::string, size_t> name2Index_
-
std::vector<tensor2d_t> inputTensors_
-
std::vector<tensor2d_t> outputTensors_
-
using tensor_element_t = float