
Implementing Cat-and-Mouse with Unity ML-Agents

最丑王子 2023-06-30 09:27:50

Unity ML-Agents is a machine learning toolkit built on the Unity engine, used for game AI, robot control, virtual reality, and more. It provides a rich set of interfaces and features that let developers train agents with deep learning and reinforcement learning algorithms and optimize their behavior. With ML-Agents, agents can learn, decide, and refine their behavior on their own, which makes games more playable, realistic, and interactive, and pushes innovation in both games and AI.

Without further ado, let's get straight to today's topic: cat-and-mouse.

Many of you have probably seen OpenAI's 2019 hide-and-seek experiments, where trained agents learned attack and defense strategies. Watching it got me fired up, and I have wanted to train a little agent of my own ever since.

Paper: https://d4mucfpksywv.cloudfront.net/emergent-tool-use/paper/Multi_Agent_Emergence_2019.pdf

Environment: https://github.com/openai/multi-agent-emergence-environments

Now that ChatGPT has arrived and set off a new wave of excitement around AI, an ordinary developer like me wants to ride the AI wave too!

Environment setup:

Anaconda: Anaconda | The World’s Most Popular Data Science Platform

ML-Agents: GitHub - Unity-Technologies/ml-agents: The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents using deep reinforcement learning and imitation learning.

Unity: official Unity download via Unity Hub | Unity China

I won't walk through the environment configuration step by step here.

Opening the project:

First, create a scene and add a map plus two characters: a cat (blue) and a mouse (green). Then give both of them models and scripts so they can move on their own, with the cat trying to catch and the mouse trying to escape.

Building the agents:

We need two agents in the scene: the cat and the mouse. Each one derives from the Agent class provided by Unity ML-Agents, which handles initialization, observation collection, and action handling; during training, a deep neural network policy is learned for each behavior.

Reward scheme for the cat (blue):

Catching the mouse earns a reward of 0.1, and the cat is never penalized (in early training runs, penalizing it caused the agent to stop moving altogether, since it decided that standing still was the best policy).

Reward scheme for the mouse (green):

Being caught costs 0.1, and whenever the agent's y position falls below 0.5 (it has dropped out of the scene) it loses another 0.1.

Code implementation:

Cat

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using MLAgents.Sensors;

// Cat agent: chases the mouse. Written against the pre-1.0 ML-Agents API
// (namespace MLAgents, float[] actions).
public class CatCatches : Agent
{
    public Rigidbody Mouse;     // the mouse's Rigidbody, assigned in the Inspector
    Rigidbody Cat;              // this agent's own Rigidbody
    public float speed = 30f;   // force multiplier applied to the action vector
    public bool Win = false;

    public override void Initialize()
    {
        Cat = GetComponent<Rigidbody>();
    }

    // 13 values per body (position 3 + velocity 3 + rotation 4 + angular velocity 3),
    // 26 in total, so the vector observation size in Behavior Parameters must be 26.
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(Mouse.transform.localPosition);
        sensor.AddObservation(Mouse.velocity);
        sensor.AddObservation(Mouse.rotation);
        sensor.AddObservation(Mouse.angularVelocity);

        sensor.AddObservation(Cat.transform.localPosition);
        sensor.AddObservation(Cat.velocity);
        sensor.AddObservation(Cat.rotation);
        sensor.AddObservation(Cat.angularVelocity);
    }

    // Three continuous actions: force along x, force along z, and a vertical impulse.
    public override void OnActionReceived(float[] vectorAction)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        // The vertical force is only applied while the cat sits at exactly y == 1.0
        // (note the exact float comparison).
        if (Cat.transform.localPosition.y == 1.0f)
        {
            controlSignal.y = vectorAction[2] * 10.0f;
        }
        Cat.AddForce(controlSignal * speed);
        // Falling off the arena ends the episode; the cat itself is never penalized.
        if (Cat.transform.localPosition.y < 0.5f)
        {
            EndEpisode();
        }
    }

    // Reset both bodies to their starting corners at the beginning of every episode.
    public override void OnEpisodeBegin()
    {
        Cat.transform.localPosition = new Vector3(-4f, 0.5f, -5f);
        Cat.velocity = Vector3.zero;
        Cat.rotation = Quaternion.Euler(Vector3.zero);
        Cat.angularVelocity = Vector3.zero;

        Mouse.transform.localPosition = new Vector3(4.5f, 0.5f, 4.5f);
        Mouse.velocity = Vector3.zero;
        Mouse.rotation = Quaternion.Euler(Vector3.zero);
        Mouse.angularVelocity = Vector3.zero;
    }

    // +0.1 every time the cat touches the mouse.
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.rigidbody == Mouse)
        {
            AddReward(0.1f);
        }
    }
}
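
Before kicking off training, it helps to drive the cat by hand and confirm that the forces, resets, and collision reward behave as expected. One way to do that is a Heuristic method like the sketch below, pasted inside CatCatches and used with Behavior Type set to Heuristic Only in the Behavior Parameters component. It assumes the same pre-1.0 ML-Agents API as the scripts above, where Heuristic() returns a float[]; newer releases use void Heuristic(float[] actionsOut) or Heuristic(in ActionBuffers actionsOut) instead, so match the signature to your package version.

    // Hedged sketch: manual keyboard control for quick testing of CatCatches.
    // Signature matches the older float[]-based API; adapt it for newer ML-Agents versions.
    public override float[] Heuristic()
    {
        var action = new float[3];
        action[0] = Input.GetAxis("Horizontal");            // force along x
        action[1] = Input.GetAxis("Vertical");              // force along z
        action[2] = Input.GetKey(KeyCode.Space) ? 1f : 0f;  // vertical impulse
        return action;
    }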

Mouse

using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using MLAgents;
using MLAgents.Sensors;

// Mouse agent: tries to stay away from the cat. Observations mirror CatCatches.
public class MouseElude : Agent
{
    public Rigidbody Cat;       // the cat's Rigidbody, assigned in the Inspector
    Rigidbody Mouse;            // this agent's own Rigidbody
    public float speed = 30f;   // force multiplier applied to the action vector
    public bool Win = false;

    public override void Initialize()
    {
        Mouse = GetComponent<Rigidbody>();
    }

    // The same 26 observations as the cat (13 per Rigidbody).
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(Mouse.transform.localPosition);
        sensor.AddObservation(Mouse.velocity);
        sensor.AddObservation(Mouse.rotation);
        sensor.AddObservation(Mouse.angularVelocity);

        sensor.AddObservation(Cat.transform.localPosition);
        sensor.AddObservation(Cat.velocity);
        sensor.AddObservation(Cat.rotation);
        sensor.AddObservation(Cat.angularVelocity);
    }

    // Two continuous actions: force along x and force along z.
    public override void OnActionReceived(float[] vectorAction)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = vectorAction[0];
        controlSignal.z = vectorAction[1];
        Mouse.AddForce(controlSignal * speed);
        // Falling off the arena costs 0.1 and ends the episode.
        if (Mouse.transform.localPosition.y < 0.5f)
        {
            AddReward(-0.1f);
            EndEpisode();
        }
    }

    // Reset both bodies to the same starting corners used by the cat.
    public override void OnEpisodeBegin()
    {
        Cat.transform.localPosition = new Vector3(-4f, 0.5f, -5f);
        Cat.velocity = Vector3.zero;
        Cat.rotation = Quaternion.Euler(Vector3.zero);
        Cat.angularVelocity = Vector3.zero;

        Mouse.transform.localPosition = new Vector3(4.5f, 0.5f, 4.5f);
        Mouse.velocity = Vector3.zero;
        Mouse.rotation = Quaternion.Euler(Vector3.zero);
        Mouse.angularVelocity = Vector3.zero;
    }

    // -0.1 every time the cat touches the mouse.
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.rigidbody == Cat)
        {
            AddReward(-0.1f);
        }
    }
}

With the logic in place, here is a preview of the scripts attached to the two agents:

Training the agents:

For the scene layout, place as many copies of the training area as your machine's performance allows; I use 31.
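
If you would rather spawn those copies from code than duplicate them by hand in the Hierarchy, a small helper along these lines works; TrainingAreaSpawner, trainingAreaPrefab, the count, and the spacing are placeholder names and values for this sketch, not part of the original project. Because the agents read and reset localPosition, each copy of the prefab should contain its own floor, cat, and mouse so coordinates stay relative to that area.

using UnityEngine;

// Hypothetical helper that lays out copies of a self-contained training-area prefab
// in a grid so several environments can collect experience in parallel.
public class TrainingAreaSpawner : MonoBehaviour
{
    public GameObject trainingAreaPrefab; // prefab containing the floor, the cat, and the mouse
    public int count = 31;                // number of copies to create
    public float spacing = 20f;           // distance between neighbouring copies

    void Awake()
    {
        int perRow = Mathf.CeilToInt(Mathf.Sqrt(count));
        for (int i = 0; i < count; i++)
        {
            int row = i / perRow;
            int col = i % perRow;
            Vector3 offset = new Vector3(col * spacing, 0f, row * spacing);
            Instantiate(trainingAreaPrefab, offset, Quaternion.identity);
        }
    }
}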

Open Anaconda Prompt (anaconda3), activate the ml-agents virtual environment you set up earlier (there are plenty of tutorials online, so I won't cover that here), and start training with: mlagents-learn config/trainer_config.yaml --run-id=CatMouse_01 --train
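
The file passed to mlagents-learn holds the training hyperparameters. Below is a minimal sketch in the older trainer_config.yaml format that matches the API used by the scripts above; the values are just commonly used defaults, and the Cat and Mouse section names are assumptions that must match the Behavior Name set in each agent's Behavior Parameters component, so adjust both to your project and ML-Agents version.

# Hedged sketch of config/trainer_config.yaml in the older ML-Agents config format.
# Per-behavior sections (keyed by Behavior Name) override the defaults.
default:
    trainer: ppo
    batch_size: 1024
    buffer_size: 10240
    learning_rate: 3.0e-4
    hidden_units: 128
    num_layers: 2
    time_horizon: 64
    max_steps: 5.0e5
    summary_freq: 10000

Cat:            # assumed Behavior Name of the cat agent
    max_steps: 1.0e6

Mouse:          # assumed Behavior Name of the mouse agent
    max_steps: 1.0e6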

Once the command is running, switch back to Unity and press Play; the game will run and your model will start training.

Model training and optimization
