1. パーセプトロンのアニメーション(3D)

    前回に引き続き3Dバージョンを作ってみました。

    http://kokonotsu.net/swf/perceptron/perceptron3D.html



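    収束するのにちょっと時間がかかるようになっていますね。次元が増えているので、まあそういうものなのでしょうか。たまに線形分離できないデータができているっぽいのですが、分離平面の式がおかしいのかな。（なお、真の分離平面 randomX*x + randomY*y + randomZ*z = 0 は原点を通るので、生成されるデータは原理的には常に線形分離可能なはずです。分離できていないように見えるのは、NUM_TRAIN = 20 エポック以内に収束しきっていないだけかもしれません。）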

    ソースはこちら。Away3Dを使っています。

    package
    {
        import flash.display.BitmapData;
        import flash.display.Sprite;
        import flash.display.StageAlign;
        import flash.display.StageScaleMode;
        import flash.events.Event;
        import flash.events.MouseEvent;
        import flash.geom.Vector3D;

        import away3d.cameras.Camera3D;
        import away3d.containers.ObjectContainer3D;
        import away3d.containers.Scene3D;
        import away3d.containers.View3D;
        import away3d.core.math.Plane3D;
        import away3d.entities.Mesh;
        import away3d.materials.ColorMaterial;
        import away3d.materials.TextureMaterial;
        import away3d.primitives.CubeGeometry;
        import away3d.primitives.PlaneGeometry;
        import away3d.primitives.WireframePlane;
        import away3d.textures.BitmapTexture;

        /**
         * Animated 3D perceptron demo.
         *
         * NUM_DATA random points in [-500, 500)^3 are labelled +1 (red) / -1 (blue)
         * by a random plane through the origin. A bias-free perceptron is then
         * trained one pattern per frame, and a translucent plane whose normal is
         * the current weight vector is drawn while the camera orbits the scene.
         * Clicking the stage regenerates the data and restarts training.
         */
        public class perceptron3D extends Sprite
        {
            /** Number of training patterns generated per run. */
            private const NUM_DATA:int = 100;
            /** Maximum number of passes (epochs) over the training set. */
            private const NUM_TRAIN:int = 20;
            /** Distance of the orbiting camera from the Y axis. */
            private const CAMERA_RADIUS:Number = 1500;

            private var _w:Vector.<Number>;                // weight vector (normal of the learned plane)
            private var _points:Vector.<Vector.<Number>>;  // training patterns
            private var _labels:Vector.<int>;              // correct labels (+1 / -1)

            private var _cnt:int;           // number of completed epochs
            private var _cursor:int;        // index of the next pattern to train on
            private var _errors:int;        // weight updates made during the current epoch
            private var _finished:Boolean;  // true once training converged or hit NUM_TRAIN

            private var _view:View3D;
            private var _cubes:Array;       // cube meshes of the current data set, disposed on reset
            private var _scene:Scene3D;
            private var _camera:Camera3D;
            private var _container:ObjectContainer3D;
            private var _angle:Number = 0;
            private var _hyperPlane:Mesh;
            private var _floor:WireframePlane;

            // Shared by every cube so that a reset does not allocate new GPU textures.
            private var _cubeGeometry:CubeGeometry;
            private var _positiveMaterial:TextureMaterial;
            private var _negativeMaterial:TextureMaterial;

            public function perceptron3D()
            {
                // `stage` is only available once this sprite is on the display list.
                if (stage) {
                    setup();
                } else {
                    addEventListener(Event.ADDED_TO_STAGE, onAddedToStage);
                }
            }

            private function onAddedToStage(e:Event):void
            {
                removeEventListener(Event.ADDED_TO_STAGE, onAddedToStage);
                setup();
            }

            /** Builds the view, camera, learned-plane mesh, floor and shared cube resources, then starts. */
            private function setup():void
            {
                stage.align = StageAlign.TOP_LEFT;
                stage.scaleMode = StageScaleMode.NO_SCALE;

                _scene = new Scene3D();
                _camera = new Camera3D();
                _view = new View3D();
                _view.scene = _scene;
                _view.camera = _camera;
                _view.width = 500;
                _view.height = 500;
                addChild(_view);

                _camera.y = 700;

                var planeGeometry:PlaneGeometry = new PlaneGeometry(1200, 1200, 1, 1, true, true);
                var planeMaterial:ColorMaterial = new ColorMaterial(0xFFFFFF, 0.5);
                _hyperPlane = new Mesh(planeGeometry, planeMaterial);

                _floor = new WireframePlane(3000, 3000, 10, 10, 0x222222);
                _floor.rotationZ = 90;

                _cubeGeometry = new CubeGeometry(20, 20, 20, 1, 1, 1, false);
                _positiveMaterial = createSolidMaterial(0xFF0000);
                _negativeMaterial = createSolidMaterial(0x0000FF);
                _cubes = [];

                init();
                stage.addEventListener(MouseEvent.CLICK, onClick);
                addEventListener(Event.ENTER_FRAME, onEnterFrame);
            }

            /** Creates a texture material from a 2x2 opaque bitmap of a single colour. */
            private function createSolidMaterial(color:uint):TextureMaterial
            {
                return new TextureMaterial(new BitmapTexture(new BitmapData(2, 2, false, color)));
            }

            /**
             * Returns the dot product of the weight vector and pattern x.
             * The sign is the predicted class.
             */
            private function predict(x:Vector.<Number>):Number
            {
                var y:Number = 0;
                for (var i:int = 0; i < x.length; i++) {
                    y += _w[i] * x[i];
                }
                return y;
            }

            /**
             * One perceptron step for pattern x with label t (+1 / -1).
             * The weights move by t * x when x is misclassified or lies exactly on
             * the current plane (y * t <= 0).
             * @return true if the weights were updated.
             */
            private function train(x:Vector.<Number>, t:int):Boolean
            {
                if (predict(x) * t > 0) {
                    return false;
                }
                for (var i:int = 0; i < x.length; i++) {
                    _w[i] += t * x[i];
                }
                return true;
            }

            /** Discards the previous run, regenerates the labelled data and resets training state. */
            private function init():void
            {
                // Mesh.dispose() drops the mesh's references to the shared geometry and materials.
                for each (var oldCube:Mesh in _cubes) {
                    oldCube.dispose();
                }
                _cubes = [];

                // Always remove index 0: removing at an increasing index skips every other child.
                while (_scene.numChildren > 0) {
                    _scene.removeChildAt(0);
                }
                _scene.addChild(_hyperPlane);
                _scene.addChild(_floor);

                _cnt = 0;
                _cursor = 0;
                _errors = 0;
                _finished = false;
                _w = new Vector.<Number>();
                _w[0] = 1;
                _w[1] = 1;
                _w[2] = 1;

                _points = new Vector.<Vector.<Number>>();
                _labels = new Vector.<int>();

                // True separating plane: a*x + b*y + c*z = 0. It passes through the
                // origin, so a bias-free perceptron can represent it exactly.
                var a:Number = Math.random() * 5;
                var b:Number = Math.random() * 5;
                var c:Number = Math.random() * 5;

                for (var i:int = 0; i < NUM_DATA; i++) {
                    var p:Vector.<Number> = new Vector.<Number>();
                    p[0] = Math.random() * 1000 - 500;
                    p[1] = Math.random() * 1000 - 500;
                    p[2] = Math.random() * 1000 - 500;
                    _points.push(p);

                    var positive:Boolean = a * p[0] + b * p[1] + c * p[2] > 0;
                    _labels.push(positive ? 1 : -1);

                    var cube:Mesh = new Mesh(_cubeGeometry, positive ? _positiveMaterial : _negativeMaterial);
                    cube.x = p[0];
                    cube.y = p[1];
                    cube.z = p[2];
                    _cubes.push(cube);
                    _scene.addChild(cube);
                }
            }

            private function onClick(e:MouseEvent):void
            {
                init();
            }

            /**
             * Trains on the next pattern.
             * Training stops after an epoch with no updates (converged) or after
             * NUM_TRAIN epochs.
             */
            private function trainStep():void
            {
                if (_cursor >= _points.length) {
                    _cnt++;
                    if (_errors == 0 || _cnt >= NUM_TRAIN) {
                        _finished = true;
                        trace("fin");
                        return;
                    }
                    _cursor = 0;
                    _errors = 0;
                }
                if (train(_points[_cursor], _labels[_cursor])) {
                    _errors++;
                }
                _cursor++;
            }

            /** Per frame: train one step (until finished), orient the plane, orbit the camera, render. */
            private function onEnterFrame(e:Event):void
            {
                if (!_finished) {
                    trainStep();
                }

                // Point the mesh's forward axis along w, then tip it 90 degrees about
                // local X so the plane's normal follows w.
                // NOTE(review): lookAt uses the default up axis; a w parallel to Y may
                // give a degenerate orientation — confirm.
                var normal:Vector3D = new Vector3D(_w[0], _w[1], _w[2]);
                _hyperPlane.lookAt(normal);
                _hyperPlane.rotate(new Vector3D(1, 0, 0), 90);

                _camera.x = Math.cos(_angle) * CAMERA_RADIUS;
                _camera.z = Math.sin(_angle) * CAMERA_RADIUS;
                _angle += 0.01;

                _camera.lookAt(new Vector3D(0, 0, 0));
                _view.render();
            }
        }
    }

    Posted by Takeya Hikage on 2014年01月28日
    Categories flash 機械学習