|
1 | 1 | #include <cassert> |
2 | 2 | #include <vector> |
3 | 3 | #include <optional> |
| 4 | +#include <algorithm> |
| 5 | +#include <cmath> |
| 6 | + |
4 | 7 |
|
5 | 8 | #include "LazyBlob.h" |
6 | 9 | #include "Iterations.h" |
7 | 10 | #include "Allocator.h" |
8 | 11 | #include "Blob.h" |
9 | 12 |
|
10 | 13 | #define MAX_DIMS_COUNT 4 |
| 14 | +#define EPS 1e-9 |
11 | 15 |
|
12 | 16 | const Shape& LazyBlob::shape() const { |
13 | 17 | if (shape_.has_value()) { |
@@ -568,3 +572,145 @@ std::ostream& operator<<(std::ostream& os, const LazyBlob &b) { |
568 | 572 | } |
569 | 573 | return os; |
570 | 574 | } |
| 575 | + |
| 576 | +class LazyBlobEntropy: public LazyBlob { |
| 577 | +public: |
| 578 | + const LazyBlob &a, &b; |
| 579 | + const int classCount; |
| 580 | + LazyBlobEntropy(const LazyBlob &a, const LazyBlob &b, int classCount): |
| 581 | + a(a), b(b), classCount(classCount) {}; |
| 582 | + |
| 583 | + void initShape() const final override { |
| 584 | + shape_ = b.shape(); |
| 585 | + } |
| 586 | + |
| 587 | + float operator() (std::size_t k, std::size_t l, std::size_t i, std::size_t j) const override { |
| 588 | + assert(b(k, l, i, j) < classCount); |
| 589 | + // WARNING: если проблемы, меняем на случай с EPS |
| 590 | + return std::log(a(k, l, i, (int) b(k, l, i, j))); |
| 591 | + // return std::log(a(k, l, i, (int) b(k, l, i, j)) + EPS); |
| 592 | + } |
| 593 | +}; |
| 594 | + |
| 595 | +class LazyBlobEntropyDerivative: public LazyBlob { |
| 596 | +public: |
| 597 | + const LazyBlob &a, &b; |
| 598 | + const int classCount; |
| 599 | + LazyBlobEntropyDerivative(const LazyBlob &a, const LazyBlob &b, int classCount): |
| 600 | + a(a), b(b), classCount(classCount) {}; |
| 601 | + |
| 602 | + void initShape() const final override { |
| 603 | + shape_ = a.shape(); |
| 604 | + } |
| 605 | + |
| 606 | + float operator() (std::size_t k, std::size_t l, std::size_t i, std::size_t j) const override { |
| 607 | + assert(b(k, l, i, j) < classCount); |
| 608 | + if (j != (int) b(k, 0, 0, 0)) { |
| 609 | + return 0; |
| 610 | + } |
| 611 | + // WARNING: если проблемы, меняем на случай с EPS |
| 612 | + return - 1.0f / (a(k, l, i, j)); |
| 613 | + // return - 1.0f / (a(k, l, i, j) + EPS); |
| 614 | + } |
| 615 | +}; |
| 616 | + |
| 617 | +const LazyBlob& LazyBlob::entropy(const LazyBlob& a, int classCount) const { |
| 618 | + assert(shape().cols() == classCount); |
| 619 | + assert(shape().dim4() == a.shape().dim4()); |
| 620 | + assert(shape().dim3() == 1); |
| 621 | + assert(shape().rows() == 1); |
| 622 | + assert(a.shape().dim3() == 1); |
| 623 | + assert(a.shape().rows() == 1); |
| 624 | + assert(a.shape().cols() == 1); |
| 625 | + |
| 626 | + void* location = Allocator::allocateBytes(sizeof(LazyBlobEntropy)); |
| 627 | + return *(new(location) LazyBlobEntropy(*this, a, classCount)); |
| 628 | +} |
| 629 | + |
| 630 | +const LazyBlob& LazyBlob::entropyDerivative(const LazyBlob& a, int classCount) const { |
| 631 | + assert(shape().cols() == classCount); |
| 632 | + assert(shape().dim4() == a.shape().dim4()); |
| 633 | + assert(shape().dim3() == 1); |
| 634 | + assert(shape().rows() == 1); |
| 635 | + assert(a.shape().dim3() == 1); |
| 636 | + assert(a.shape().rows() == 1); |
| 637 | + assert(a.shape().cols() == 1); |
| 638 | + void* location = Allocator::allocateBytes(sizeof(LazyBlobEntropyDerivative)); |
| 639 | + return *(new(location) LazyBlobEntropyDerivative(*this, a, classCount)); |
| 640 | +} |
| 641 | + |
| 642 | +class LazyBlobMaxPool: public LazyBlob { |
| 643 | +public: |
| 644 | + const LazyBlob &a; |
| 645 | + LazyBlobMaxPool(const LazyBlob &a): a(a) {}; |
| 646 | + |
| 647 | + void initShape() const final override { |
| 648 | + shape_ = { |
| 649 | + { |
| 650 | + a.shape().dim4(), a.shape().dim3(), a.shape().rows() / 2, a.shape().cols() / 2 |
| 651 | + }, |
| 652 | + a.shape().dimsCount |
| 653 | + }; |
| 654 | + } |
| 655 | + |
| 656 | + float operator() (std::size_t k, std::size_t l, std::size_t i, std::size_t j) const override { |
| 657 | + return std::max( |
| 658 | + std::max(a(k, l, i * 2, j * 2), a(k, l, i * 2 + 1, j * 2)), |
| 659 | + std::max(a(k, l, i * 2, j * 2 + 1), a(k, l, i * 2 + 1, j * 2 + 1)) |
| 660 | + ); |
| 661 | + } |
| 662 | +}; |
| 663 | + |
| 664 | +class LazyBlobMaxPoolDerivative: public LazyBlob { |
| 665 | +public: |
| 666 | + const LazyBlob &a, &b; |
| 667 | + LazyBlobMaxPoolDerivative(const LazyBlob &a, const LazyBlob& b): a(a), b(b) {}; |
| 668 | + |
| 669 | + void initShape() const final override { |
| 670 | + shape_ = a.shape(); |
| 671 | + } |
| 672 | + |
| 673 | + float operator() (std::size_t k, std::size_t l, std::size_t i, std::size_t j) const override { |
| 674 | + size_t start_i = (i / 2) * 2; |
| 675 | + size_t start_j = (j / 2) * 2; |
| 676 | + size_t indexOfMax_i = start_i; |
| 677 | + size_t indexOfMax_j = start_j; |
| 678 | + float max = a(k, l, indexOfMax_i, indexOfMax_j); |
| 679 | + if (max < a(k, l, start_i, start_j + 1)) { |
| 680 | + indexOfMax_i = start_i; |
| 681 | + indexOfMax_j = start_j + 1; |
| 682 | + max = a(k, l, indexOfMax_i, indexOfMax_j); |
| 683 | + } |
| 684 | + |
| 685 | + if (max < a(k, l, start_i + 1, start_j)) { |
| 686 | + indexOfMax_i = start_i + 1; |
| 687 | + indexOfMax_j = start_j; |
| 688 | + max = a(k, l, indexOfMax_i, indexOfMax_j); |
| 689 | + } |
| 690 | + |
| 691 | + if (max < a(k, l, start_i + 1, start_j + 1)) { |
| 692 | + indexOfMax_i = start_i + 1; |
| 693 | + indexOfMax_j = start_j + 1; |
| 694 | + max = a(k, l, indexOfMax_i, indexOfMax_j); |
| 695 | + } |
| 696 | + |
| 697 | + if (indexOfMax_i == i && indexOfMax_j == j) |
| 698 | + return b(k, l, i / 2, j / 2); |
| 699 | + |
| 700 | + return 0.0f; |
| 701 | + } |
| 702 | +}; |
| 703 | + |
| 704 | +const LazyBlob& LazyBlob::maxPool() const { |
| 705 | + assert(shape().cols() % 2 == 0); |
| 706 | + assert(shape().rows() % 2 == 0); |
| 707 | + void* location = Allocator::allocateBytes(sizeof(LazyBlobEntropyDerivative)); |
| 708 | + return *(new(location) LazyBlobMaxPool(*this)); |
| 709 | +} |
| 710 | + |
| 711 | +const LazyBlob& LazyBlob::maxPoolDerivative(const LazyBlob& b) const { |
| 712 | + assert(shape().cols() % 2 == 0); |
| 713 | + assert(shape().rows() % 2 == 0); |
| 714 | + void* location = Allocator::allocateBytes(sizeof(LazyBlobMaxPoolDerivative)); |
| 715 | + return *(new(location) LazyBlobMaxPoolDerivative(*this, b)); |
| 716 | +} |
0 commit comments